Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/.circleci/config.yml +179 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/config/__init__.py +13 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/config/compat.py +229 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/config/config.py +202 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/config/defaults.py +598 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/__init__.py +18 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/build.py +397 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/catalog.py +221 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/common.py +149 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/dataset_mapper.py +149 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/datasets/README.md +9 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/datasets/__init__.py +9 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/datasets/builtin.py +220 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/datasets/builtin_meta.py +267 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/datasets/cityscapes.py +329 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/datasets/coco.py +466 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/datasets/lvis.py +209 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/datasets/lvis_v0_5_categories.py +0 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/datasets/pascal_voc.py +80 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/datasets/register_coco.py +129 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/detection_utils.py +516 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/samplers/__init__.py +10 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/samplers/distributed_sampler.py +199 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/samplers/grouped_batch_sampler.py +47 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/transforms/__init__.py +6 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/transforms/transform.py +241 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/transforms/transform_gen.py +534 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/evaluation/__init__.py +12 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/evaluation/cityscapes_evaluation.py +187 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/evaluation/coco_evaluation.py +512 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/evaluation/evaluator.py +196 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/evaluation/lvis_evaluation.py +350 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/evaluation/panoptic_evaluation.py +167 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/evaluation/pascal_voc_evaluation.py +294 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/evaluation/rotated_coco_evaluation.py +204 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/evaluation/sem_seg_evaluation.py +168 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/evaluation/testing.py +78 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/export/README.md +10 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/export/__init__.py +5 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/export/api.py +277 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/export/c10.py +503 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/export/caffe2_export.py +204 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/export/caffe2_inference.py +136 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/export/caffe2_modeling.py +493 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/export/patcher.py +153 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/export/shared.py +1034 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h +35 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp +39 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu +130 -0
- Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h +363 -0
Leffa/preprocess/humanparsing/mhp_extension/detectron2/.circleci/config.yml
ADDED
@@ -0,0 +1,179 @@
# Python CircleCI 2.0 configuration file
#
# Check https://circleci.com/docs/2.0/language-python/ for more details
#
version: 2

# -------------------------------------------------------------------------------------
# Environments to run the jobs in
# -------------------------------------------------------------------------------------
cpu: &cpu
  docker:
    - image: circleci/python:3.6.8-stretch
  resource_class: medium

gpu: &gpu
  machine:
    image: ubuntu-1604:201903-01
    docker_layer_caching: true
  resource_class: gpu.small

# -------------------------------------------------------------------------------------
# Re-usable commands
# -------------------------------------------------------------------------------------
install_python: &install_python
  - run:
      name: Install Python
      working_directory: ~/
      command: |
        pyenv install 3.6.1
        pyenv global 3.6.1

setup_venv: &setup_venv
  - run:
      name: Setup Virtual Env
      working_directory: ~/
      command: |
        python -m venv ~/venv
        echo ". ~/venv/bin/activate" >> $BASH_ENV
        . ~/venv/bin/activate
        python --version
        which python
        which pip
        pip install --upgrade pip

install_dep: &install_dep
  - run:
      name: Install Dependencies
      command: |
        pip install --progress-bar off -U 'git+https://github.com/facebookresearch/fvcore'
        pip install --progress-bar off cython opencv-python
        pip install --progress-bar off 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
        pip install --progress-bar off torch torchvision

install_detectron2: &install_detectron2
  - run:
      name: Install Detectron2
      command: |
        gcc --version
        pip install -U --progress-bar off -e .[dev]
        python -m detectron2.utils.collect_env

install_nvidia_driver: &install_nvidia_driver
  - run:
      name: Install nvidia driver
      working_directory: ~/
      command: |
        wget -q 'https://s3.amazonaws.com/ossci-linux/nvidia_driver/NVIDIA-Linux-x86_64-430.40.run'
        sudo /bin/bash ./NVIDIA-Linux-x86_64-430.40.run -s --no-drm
        nvidia-smi

run_unittests: &run_unittests
  - run:
      name: Run Unit Tests
      command: |
        python -m unittest discover -v -s tests

# -------------------------------------------------------------------------------------
# Jobs to run
# -------------------------------------------------------------------------------------
jobs:
  cpu_tests:
    <<: *cpu

    working_directory: ~/detectron2

    steps:
      - checkout
      - <<: *setup_venv

      # Cache the venv directory that contains dependencies
      - restore_cache:
          keys:
            - cache-key-{{ .Branch }}-ID-20200425

      - <<: *install_dep

      - save_cache:
          paths:
            - ~/venv
          key: cache-key-{{ .Branch }}-ID-20200425

      - <<: *install_detectron2

      - run:
          name: isort
          command: |
            isort -c -sp .
      - run:
          name: black
          command: |
            black --check -l 100 .
      - run:
          name: flake8
          command: |
            flake8 .

      - <<: *run_unittests

  gpu_tests:
    <<: *gpu

    working_directory: ~/detectron2

    steps:
      - checkout
      - <<: *install_nvidia_driver

      - run:
          name: Install nvidia-docker
          working_directory: ~/
          command: |
            curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add -
            distribution=$(. /etc/os-release;echo $ID$VERSION_ID)
            curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | \
              sudo tee /etc/apt/sources.list.d/nvidia-docker.list
            sudo apt-get update && sudo apt-get install -y nvidia-docker2
            # reload the docker daemon configuration
            sudo pkill -SIGHUP dockerd

      - run:
          name: Launch docker
          working_directory: ~/detectron2/docker
          command: |
            nvidia-docker build -t detectron2:v0 -f Dockerfile-circleci .
            nvidia-docker run -itd --name d2 detectron2:v0
            docker exec -it d2 nvidia-smi

      - run:
          name: Build Detectron2
          command: |
            docker exec -it d2 pip install 'git+https://github.com/facebookresearch/fvcore'
            docker cp ~/detectron2 d2:/detectron2
            # This will build d2 for the target GPU arch only
            docker exec -it d2 pip install -e /detectron2
            docker exec -it d2 python3 -m detectron2.utils.collect_env
            docker exec -it d2 python3 -c 'import torch; assert(torch.cuda.is_available())'

      - run:
          name: Run Unit Tests
          command: |
            docker exec -e CIRCLECI=true -it d2 python3 -m unittest discover -v -s /detectron2/tests

workflows:
  version: 2
  regular_test:
    jobs:
      - cpu_tests
      - gpu_tests

  #nightly_test:
    #jobs:
      #- gpu_tests
    #triggers:
      #- schedule:
          #cron: "0 0 * * *"
          #filters:
            #branches:
              #only:
                #- master
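The job definitions above lean heavily on YAML anchors (`&cpu`, `&install_dep`, ...) and merge keys (`<<: *cpu`). A minimal sketch, not part of the diff, of how one might sanity-check the file locally before pushing, assuming PyYAML is installed and the file is saved at the path shown above:

```python
# Sketch: parse the CircleCI config and confirm its anchors/aliases resolve.
import yaml

with open(".circleci/config.yml") as f:
    cfg = yaml.safe_load(f)  # anchors and merge keys are expanded into plain dicts/lists

assert cfg["version"] == 2
assert set(cfg["jobs"]) == {"cpu_tests", "gpu_tests"}
# If the `<<: *cpu` merge resolved, cpu_tests should carry the docker executor.
print(cfg["jobs"]["cpu_tests"].get("docker"))
```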
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/config/__init__.py
ADDED
@@ -0,0 +1,13 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .compat import downgrade_config, upgrade_config
from .config import CfgNode, get_cfg, global_cfg, set_global_cfg, configurable

__all__ = [
    "CfgNode",
    "get_cfg",
    "global_cfg",
    "set_global_cfg",
    "downgrade_config",
    "upgrade_config",
    "configurable",
]
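This `__init__.py` defines the public surface of the config package. A minimal usage sketch, not part of the diff, assuming the vendored `detectron2` package is importable on `PYTHONPATH`:

```python
# Sketch: typical use of the re-exported config API.
from detectron2.config import get_cfg, set_global_cfg, global_cfg

cfg = get_cfg()                 # a clone of the defaults defined in defaults.py
cfg.MODEL.DEVICE = "cpu"        # override any default in place
set_global_cfg(cfg)             # optionally expose it process-wide
print(global_cfg.MODEL.DEVICE)  # -> "cpu"
```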
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/config/compat.py
ADDED
@@ -0,0 +1,229 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Backward compatibility of configs.

Instructions to bump version:
+ It's not needed to bump version if new keys are added.
  It's only needed when backward-incompatible changes happen
  (i.e., some existing keys disappear, or the meaning of a key changes)
+ To bump version, do the following:
    1. Increment _C.VERSION in defaults.py
    2. Add a converter in this file.

      Each ConverterVX has a function "upgrade" which in-place upgrades config from X-1 to X,
      and a function "downgrade" which in-place downgrades config from X to X-1

      In each function, VERSION is left unchanged.

      Each converter assumes that its input has the relevant keys
      (i.e., the input is not a partial config).
    3. Run the tests (test_config.py) to make sure the upgrade & downgrade
      functions are consistent.
"""

import logging
from typing import List, Optional, Tuple

from .config import CfgNode as CN
from .defaults import _C

__all__ = ["upgrade_config", "downgrade_config"]


def upgrade_config(cfg: CN, to_version: Optional[int] = None) -> CN:
    """
    Upgrade a config from its current version to a newer version.

    Args:
        cfg (CfgNode):
        to_version (int): defaults to the latest version.
    """
    cfg = cfg.clone()
    if to_version is None:
        to_version = _C.VERSION

    assert cfg.VERSION <= to_version, "Cannot upgrade from v{} to v{}!".format(
        cfg.VERSION, to_version
    )
    for k in range(cfg.VERSION, to_version):
        converter = globals()["ConverterV" + str(k + 1)]
        converter.upgrade(cfg)
        cfg.VERSION = k + 1
    return cfg


def downgrade_config(cfg: CN, to_version: int) -> CN:
    """
    Downgrade a config from its current version to an older version.

    Args:
        cfg (CfgNode):
        to_version (int):

    Note:
        A general downgrade of arbitrary configs is not always possible due to the
        different functionalities in different versions.
        The purpose of downgrade is only to recover the defaults in old versions,
        allowing it to load an old partial yaml config.
        Therefore, the implementation only needs to fill in the default values
        in the old version when a general downgrade is not possible.
    """
    cfg = cfg.clone()
    assert cfg.VERSION >= to_version, "Cannot downgrade from v{} to v{}!".format(
        cfg.VERSION, to_version
    )
    for k in range(cfg.VERSION, to_version, -1):
        converter = globals()["ConverterV" + str(k)]
        converter.downgrade(cfg)
        cfg.VERSION = k - 1
    return cfg


def guess_version(cfg: CN, filename: str) -> int:
    """
    Guess the version of a partial config where the VERSION field is not specified.
    Returns the version, or the latest if cannot make a guess.

    This makes it easier for users to migrate.
    """
    logger = logging.getLogger(__name__)

    def _has(name: str) -> bool:
        cur = cfg
        for n in name.split("."):
            if n not in cur:
                return False
            cur = cur[n]
        return True

    # Most users' partial configs have "MODEL.WEIGHT", so guess on it
    ret = None
    if _has("MODEL.WEIGHT") or _has("TEST.AUG_ON"):
        ret = 1

    if ret is not None:
        logger.warning("Config '{}' has no VERSION. Assuming it to be v{}.".format(filename, ret))
    else:
        ret = _C.VERSION
        logger.warning(
            "Config '{}' has no VERSION. Assuming it to be compatible with latest v{}.".format(
                filename, ret
            )
        )
    return ret


def _rename(cfg: CN, old: str, new: str) -> None:
    old_keys = old.split(".")
    new_keys = new.split(".")

    def _set(key_seq: List[str], val: str) -> None:
        cur = cfg
        for k in key_seq[:-1]:
            if k not in cur:
                cur[k] = CN()
            cur = cur[k]
        cur[key_seq[-1]] = val

    def _get(key_seq: List[str]) -> CN:
        cur = cfg
        for k in key_seq:
            cur = cur[k]
        return cur

    def _del(key_seq: List[str]) -> None:
        cur = cfg
        for k in key_seq[:-1]:
            cur = cur[k]
        del cur[key_seq[-1]]
        if len(cur) == 0 and len(key_seq) > 1:
            _del(key_seq[:-1])

    _set(new_keys, _get(old_keys))
    _del(old_keys)


class _RenameConverter:
    """
    A converter that handles simple rename.
    """

    RENAME: List[Tuple[str, str]] = []  # list of tuples of (old name, new name)

    @classmethod
    def upgrade(cls, cfg: CN) -> None:
        for old, new in cls.RENAME:
            _rename(cfg, old, new)

    @classmethod
    def downgrade(cls, cfg: CN) -> None:
        for old, new in cls.RENAME[::-1]:
            _rename(cfg, new, old)


class ConverterV1(_RenameConverter):
    RENAME = [("MODEL.RPN_HEAD.NAME", "MODEL.RPN.HEAD_NAME")]


class ConverterV2(_RenameConverter):
    """
    A large bulk of rename, before public release.
    """

    RENAME = [
        ("MODEL.WEIGHT", "MODEL.WEIGHTS"),
        ("MODEL.PANOPTIC_FPN.SEMANTIC_LOSS_SCALE", "MODEL.SEM_SEG_HEAD.LOSS_WEIGHT"),
        ("MODEL.PANOPTIC_FPN.RPN_LOSS_SCALE", "MODEL.RPN.LOSS_WEIGHT"),
        ("MODEL.PANOPTIC_FPN.INSTANCE_LOSS_SCALE", "MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT"),
        ("MODEL.PANOPTIC_FPN.COMBINE_ON", "MODEL.PANOPTIC_FPN.COMBINE.ENABLED"),
        (
            "MODEL.PANOPTIC_FPN.COMBINE_OVERLAP_THRESHOLD",
            "MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH",
        ),
        (
            "MODEL.PANOPTIC_FPN.COMBINE_STUFF_AREA_LIMIT",
            "MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT",
        ),
        (
            "MODEL.PANOPTIC_FPN.COMBINE_INSTANCES_CONFIDENCE_THRESHOLD",
            "MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH",
        ),
        ("MODEL.ROI_HEADS.SCORE_THRESH", "MODEL.ROI_HEADS.SCORE_THRESH_TEST"),
        ("MODEL.ROI_HEADS.NMS", "MODEL.ROI_HEADS.NMS_THRESH_TEST"),
        ("MODEL.RETINANET.INFERENCE_SCORE_THRESHOLD", "MODEL.RETINANET.SCORE_THRESH_TEST"),
        ("MODEL.RETINANET.INFERENCE_TOPK_CANDIDATES", "MODEL.RETINANET.TOPK_CANDIDATES_TEST"),
        ("MODEL.RETINANET.INFERENCE_NMS_THRESHOLD", "MODEL.RETINANET.NMS_THRESH_TEST"),
        ("TEST.DETECTIONS_PER_IMG", "TEST.DETECTIONS_PER_IMAGE"),
        ("TEST.AUG_ON", "TEST.AUG.ENABLED"),
        ("TEST.AUG_MIN_SIZES", "TEST.AUG.MIN_SIZES"),
        ("TEST.AUG_MAX_SIZE", "TEST.AUG.MAX_SIZE"),
        ("TEST.AUG_FLIP", "TEST.AUG.FLIP"),
    ]

    @classmethod
    def upgrade(cls, cfg: CN) -> None:
        super().upgrade(cfg)

        if cfg.MODEL.META_ARCHITECTURE == "RetinaNet":
            _rename(
                cfg, "MODEL.RETINANET.ANCHOR_ASPECT_RATIOS", "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS"
            )
            _rename(cfg, "MODEL.RETINANET.ANCHOR_SIZES", "MODEL.ANCHOR_GENERATOR.SIZES")
            del cfg["MODEL"]["RPN"]["ANCHOR_SIZES"]
            del cfg["MODEL"]["RPN"]["ANCHOR_ASPECT_RATIOS"]
        else:
            _rename(cfg, "MODEL.RPN.ANCHOR_ASPECT_RATIOS", "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS")
            _rename(cfg, "MODEL.RPN.ANCHOR_SIZES", "MODEL.ANCHOR_GENERATOR.SIZES")
            del cfg["MODEL"]["RETINANET"]["ANCHOR_SIZES"]
            del cfg["MODEL"]["RETINANET"]["ANCHOR_ASPECT_RATIOS"]
        del cfg["MODEL"]["RETINANET"]["ANCHOR_STRIDES"]

    @classmethod
    def downgrade(cls, cfg: CN) -> None:
        super().downgrade(cfg)

        _rename(cfg, "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS", "MODEL.RPN.ANCHOR_ASPECT_RATIOS")
        _rename(cfg, "MODEL.ANCHOR_GENERATOR.SIZES", "MODEL.RPN.ANCHOR_SIZES")
        cfg.MODEL.RETINANET.ANCHOR_ASPECT_RATIOS = cfg.MODEL.RPN.ANCHOR_ASPECT_RATIOS
        cfg.MODEL.RETINANET.ANCHOR_SIZES = cfg.MODEL.RPN.ANCHOR_SIZES
        cfg.MODEL.RETINANET.ANCHOR_STRIDES = []  # this is not used anywhere in any version
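A minimal sketch, not part of the diff, of the round-trip these converters provide; it mirrors the consistency check that the docstring above says test_config.py should run, and assumes the full default config (598 lines, only partially shown in this view) is available:

```python
# Sketch: downgrade the latest defaults to v1 and upgrade them back.
from detectron2.config import get_cfg, downgrade_config, upgrade_config

latest = get_cfg()                           # VERSION == 2 per defaults.py
old = downgrade_config(latest, to_version=1) # applies ConverterV2.downgrade
assert old.VERSION == 1
assert "WEIGHT" in old.MODEL                 # MODEL.WEIGHTS renamed back to MODEL.WEIGHT

roundtrip = upgrade_config(old)              # applies ConverterV2.upgrade
assert roundtrip.VERSION == latest.VERSION
```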
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/config/config.py
ADDED
@@ -0,0 +1,202 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import functools
import inspect
import logging
from fvcore.common.config import CfgNode as _CfgNode
from fvcore.common.file_io import PathManager


class CfgNode(_CfgNode):
    """
    The same as `fvcore.common.config.CfgNode`, but different in:

    1. Use unsafe yaml loading by default.
       Note that this may lead to arbitrary code execution: you must not
       load a config file from untrusted sources before manually inspecting
       the content of the file.
    2. Support config versioning.
       When attempting to merge an old config, it will convert the old config automatically.
    """

    # Note that the default value of allow_unsafe is changed to True
    def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = True) -> None:
        assert PathManager.isfile(cfg_filename), f"Config file '{cfg_filename}' does not exist!"
        loaded_cfg = _CfgNode.load_yaml_with_base(cfg_filename, allow_unsafe=allow_unsafe)
        loaded_cfg = type(self)(loaded_cfg)

        # defaults.py needs to import CfgNode
        from .defaults import _C

        latest_ver = _C.VERSION
        assert (
            latest_ver == self.VERSION
        ), "CfgNode.merge_from_file is only allowed on a config object of latest version!"

        logger = logging.getLogger(__name__)

        loaded_ver = loaded_cfg.get("VERSION", None)
        if loaded_ver is None:
            from .compat import guess_version

            loaded_ver = guess_version(loaded_cfg, cfg_filename)
        assert loaded_ver <= self.VERSION, "Cannot merge a v{} config into a v{} config.".format(
            loaded_ver, self.VERSION
        )

        if loaded_ver == self.VERSION:
            self.merge_from_other_cfg(loaded_cfg)
        else:
            # compat.py needs to import CfgNode
            from .compat import upgrade_config, downgrade_config

            logger.warning(
                "Loading an old v{} config file '{}' by automatically upgrading to v{}. "
                "See docs/CHANGELOG.md for instructions to update your files.".format(
                    loaded_ver, cfg_filename, self.VERSION
                )
            )
            # To convert, first obtain a full config at an old version
            old_self = downgrade_config(self, to_version=loaded_ver)
            old_self.merge_from_other_cfg(loaded_cfg)
            new_config = upgrade_config(old_self)
            self.clear()
            self.update(new_config)

    def dump(self, *args, **kwargs):
        """
        Returns:
            str: a yaml string representation of the config
        """
        # to make it show up in docs
        return super().dump(*args, **kwargs)


global_cfg = CfgNode()


def get_cfg() -> CfgNode:
    """
    Get a copy of the default config.

    Returns:
        a detectron2 CfgNode instance.
    """
    from .defaults import _C

    return _C.clone()


def set_global_cfg(cfg: CfgNode) -> None:
    """
    Let the global config point to the given cfg.

    Assume that the given "cfg" has the key "KEY", after calling
    `set_global_cfg(cfg)`, the key can be accessed by:

    .. code-block:: python

        from detectron2.config import global_cfg
        print(global_cfg.KEY)

    By using a hacky global config, you can access these configs anywhere,
    without having to pass the config object or the values deep into the code.
    This is a hacky feature introduced for quick prototyping / research exploration.
    """
    global global_cfg
    global_cfg.clear()
    global_cfg.update(cfg)


def configurable(init_func):
    """
    Decorate a class's __init__ method so that it can be called with a CfgNode
    object using the class's from_config classmethod.

    Examples:

    .. code-block:: python

        class A:
            @configurable
            def __init__(self, a, b=2, c=3):
                pass

            @classmethod
            def from_config(cls, cfg):
                # Returns kwargs to be passed to __init__
                return {"a": cfg.A, "b": cfg.B}

        a1 = A(a=1, b=2)  # regular construction
        a2 = A(cfg)  # construct with a cfg
        a3 = A(cfg, b=3, c=4)  # construct with extra overwrite
    """
    assert init_func.__name__ == "__init__", "@configurable should only be used for __init__!"
    if init_func.__module__.startswith("detectron2."):
        assert (
            init_func.__doc__ is not None and "experimental" in init_func.__doc__
        ), f"configurable {init_func} should be marked experimental"

    @functools.wraps(init_func)
    def wrapped(self, *args, **kwargs):
        try:
            from_config_func = type(self).from_config
        except AttributeError:
            raise AttributeError("Class with @configurable must have a 'from_config' classmethod.")
        if not inspect.ismethod(from_config_func):
            raise TypeError("Class with @configurable must have a 'from_config' classmethod.")

        if _called_with_cfg(*args, **kwargs):
            explicit_args = _get_args_from_config(from_config_func, *args, **kwargs)
            init_func(self, **explicit_args)
        else:
            init_func(self, *args, **kwargs)

    return wrapped


def _get_args_from_config(from_config_func, *args, **kwargs):
    """
    Use `from_config` to obtain explicit arguments.

    Returns:
        dict: arguments to be used for cls.__init__
    """
    signature = inspect.signature(from_config_func)
    if list(signature.parameters.keys())[0] != "cfg":
        raise TypeError(
            f"{from_config_func.__self__}.from_config must take 'cfg' as the first argument!"
        )
    support_var_arg = any(
        param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD]
        for param in signature.parameters.values()
    )
    if support_var_arg:  # forward all arguments to from_config, if from_config accepts them
        ret = from_config_func(*args, **kwargs)
    else:
        # forward supported arguments to from_config
        supported_arg_names = set(signature.parameters.keys())
        extra_kwargs = {}
        for name in list(kwargs.keys()):
            if name not in supported_arg_names:
                extra_kwargs[name] = kwargs.pop(name)
        ret = from_config_func(*args, **kwargs)
        # forward the other arguments to __init__
        ret.update(extra_kwargs)
    return ret


def _called_with_cfg(*args, **kwargs):
    """
    Returns:
        bool: whether the arguments contain CfgNode and should be considered
            forwarded to from_config.
    """
    if len(args) and isinstance(args[0], _CfgNode):
        return True
    if isinstance(kwargs.pop("cfg", None), _CfgNode):
        return True
    # `from_config`'s first argument is forced to be "cfg".
    # So the above check covers all cases.
    return False
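A minimal sketch, not part of the diff, of the `@configurable` contract in user code. The `Head` class, its parameters, and the config keys it reads are illustrative only; note that the "experimental" docstring assertion above applies only to `__init__` methods defined inside the `detectron2.*` namespace:

```python
# Sketch: a user class constructible either with explicit args or with a CfgNode.
from detectron2.config import configurable, get_cfg

class Head:
    @configurable
    def __init__(self, num_classes, score_thresh=0.05):
        self.num_classes = num_classes
        self.score_thresh = score_thresh

    @classmethod
    def from_config(cls, cfg):
        # Translate a CfgNode into the explicit __init__ kwargs.
        return {
            "num_classes": cfg.MODEL.ROI_HEADS.NUM_CLASSES,
            "score_thresh": cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST,
        }

cfg = get_cfg()
h1 = Head(num_classes=3)          # regular construction
h2 = Head(cfg)                    # from_config path: num_classes=80, score_thresh=0.05
h3 = Head(cfg, score_thresh=0.5)  # cfg plus an explicit overwrite of one argument
```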
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/config/defaults.py
ADDED
@@ -0,0 +1,598 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .config import CfgNode as CN

# -----------------------------------------------------------------------------
# Convention about Training / Test specific parameters
# -----------------------------------------------------------------------------
# Whenever an argument can be either used for training or for testing, the
# corresponding name will be post-fixed by a _TRAIN for a training parameter,
# or _TEST for a test-specific parameter.
# For example, the number of images during training will be
# IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be
# IMAGES_PER_BATCH_TEST

# -----------------------------------------------------------------------------
# Config definition
# -----------------------------------------------------------------------------

_C = CN()

# The version number, to upgrade from old configs to new ones if any
# changes happen. It's recommended to keep a VERSION in your config file.
_C.VERSION = 2

_C.MODEL = CN()
_C.MODEL.LOAD_PROPOSALS = False
_C.MODEL.MASK_ON = False
_C.MODEL.KEYPOINT_ON = False
_C.MODEL.DEVICE = "cuda"
_C.MODEL.META_ARCHITECTURE = "GeneralizedRCNN"

# Path (possibly with schema like catalog:// or detectron2://) to a checkpoint file
# to be loaded to the model. You can find available models in the model zoo.
_C.MODEL.WEIGHTS = ""

# Values to be used for image normalization (BGR order, since INPUT.FORMAT defaults to BGR).
# To train on images of different number of channels, just set different mean & std.
# Default values are the mean pixel value from ImageNet: [103.53, 116.28, 123.675]
_C.MODEL.PIXEL_MEAN = [103.530, 116.280, 123.675]
# When using pre-trained models in Detectron1 or any MSRA models,
# std has been absorbed into its conv1 weights, so the std needs to be set 1.
# Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std)
_C.MODEL.PIXEL_STD = [1.0, 1.0, 1.0]


# -----------------------------------------------------------------------------
# INPUT
# -----------------------------------------------------------------------------
_C.INPUT = CN()
# Size of the smallest side of the image during training
_C.INPUT.MIN_SIZE_TRAIN = (800,)
# Sample size of smallest side by choice or random selection from range given by
# INPUT.MIN_SIZE_TRAIN
_C.INPUT.MIN_SIZE_TRAIN_SAMPLING = "choice"
# Maximum size of the side of the image during training
_C.INPUT.MAX_SIZE_TRAIN = 1333
# Size of the smallest side of the image during testing. Set to zero to disable resize in testing.
_C.INPUT.MIN_SIZE_TEST = 800
# Maximum size of the side of the image during testing
_C.INPUT.MAX_SIZE_TEST = 1333

# `True` if cropping is used for data augmentation during training
_C.INPUT.CROP = CN({"ENABLED": False})
# Cropping type:
# - "relative" crop (H * CROP.SIZE[0], W * CROP.SIZE[1]) part of an input of size (H, W)
# - "relative_range" uniformly sample relative crop size from between [CROP.SIZE[0], CROP.SIZE[1]]
#   and [1, 1] and use it as in "relative" scenario.
# - "absolute" crop part of an input with absolute size: (CROP.SIZE[0], CROP.SIZE[1]).
_C.INPUT.CROP.TYPE = "relative_range"
# Size of crop in range (0, 1] if CROP.TYPE is "relative" or "relative_range" and in number of
# pixels if CROP.TYPE is "absolute"
_C.INPUT.CROP.SIZE = [0.9, 0.9]


# Whether the model needs RGB, YUV, HSV etc.
# Should be one of the modes defined here, as we use PIL to read the image:
# https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-modes
# with BGR being the one exception. One can set image format to BGR, we will
# internally use RGB for conversion and flip the channels over
_C.INPUT.FORMAT = "BGR"
# The ground truth mask format that the model will use.
# Mask R-CNN supports either "polygon" or "bitmask" as ground truth.
_C.INPUT.MASK_FORMAT = "polygon"  # alternative: "bitmask"


# -----------------------------------------------------------------------------
# Dataset
# -----------------------------------------------------------------------------
_C.DATASETS = CN()
# List of the dataset names for training. Must be registered in DatasetCatalog
_C.DATASETS.TRAIN = ()
# List of the pre-computed proposal files for training, which must be consistent
# with data listed in DATASETS.TRAIN.
_C.DATASETS.PROPOSAL_FILES_TRAIN = ()
# Number of top scoring precomputed proposals to keep for training
_C.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN = 2000
# List of the dataset names for testing. Must be registered in DatasetCatalog
_C.DATASETS.TEST = ()
# List of the pre-computed proposal files for test, which must be consistent
# with data listed in DATASETS.TEST.
_C.DATASETS.PROPOSAL_FILES_TEST = ()
# Number of top scoring precomputed proposals to keep for test
_C.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST = 1000

# -----------------------------------------------------------------------------
# DataLoader
# -----------------------------------------------------------------------------
_C.DATALOADER = CN()
# Number of data loading threads
_C.DATALOADER.NUM_WORKERS = 4
# If True, each batch should contain only images for which the aspect ratio
# is compatible. This groups portrait images together, and landscape images
# are not batched with portrait images.
_C.DATALOADER.ASPECT_RATIO_GROUPING = True
# Options: TrainingSampler, RepeatFactorTrainingSampler
_C.DATALOADER.SAMPLER_TRAIN = "TrainingSampler"
# Repeat threshold for RepeatFactorTrainingSampler
_C.DATALOADER.REPEAT_THRESHOLD = 0.0
# if True, the dataloader will filter out images that have no associated
# annotations at train time.
_C.DATALOADER.FILTER_EMPTY_ANNOTATIONS = True

# ---------------------------------------------------------------------------- #
# Backbone options
# ---------------------------------------------------------------------------- #
_C.MODEL.BACKBONE = CN()

_C.MODEL.BACKBONE.NAME = "build_resnet_backbone"
# Freeze the first several stages so they are not trained.
# There are 5 stages in ResNet. The first is a convolution, and the following
# stages are each group of residual blocks.
_C.MODEL.BACKBONE.FREEZE_AT = 2


# ---------------------------------------------------------------------------- #
# FPN options
# ---------------------------------------------------------------------------- #
_C.MODEL.FPN = CN()
# Names of the input feature maps to be used by FPN
# They must have contiguous power of 2 strides
# e.g., ["res2", "res3", "res4", "res5"]
_C.MODEL.FPN.IN_FEATURES = []
_C.MODEL.FPN.OUT_CHANNELS = 256

# Options: "" (no norm), "GN"
_C.MODEL.FPN.NORM = ""

# Types for fusing the FPN top-down and lateral features. Can be either "sum" or "avg"
_C.MODEL.FPN.FUSE_TYPE = "sum"


# ---------------------------------------------------------------------------- #
# Proposal generator options
# ---------------------------------------------------------------------------- #
_C.MODEL.PROPOSAL_GENERATOR = CN()
# Current proposal generators include "RPN", "RRPN" and "PrecomputedProposals"
_C.MODEL.PROPOSAL_GENERATOR.NAME = "RPN"
# Proposal height and width both need to be greater than MIN_SIZE
# (at the scale used during training or inference)
_C.MODEL.PROPOSAL_GENERATOR.MIN_SIZE = 0


# ---------------------------------------------------------------------------- #
# Anchor generator options
# ---------------------------------------------------------------------------- #
_C.MODEL.ANCHOR_GENERATOR = CN()
# The generator can be any name in the ANCHOR_GENERATOR registry
_C.MODEL.ANCHOR_GENERATOR.NAME = "DefaultAnchorGenerator"
# Anchor sizes (i.e. sqrt of area) in absolute pixels w.r.t. the network input.
# Format: list[list[float]]. SIZES[i] specifies the list of sizes
# to use for IN_FEATURES[i]; len(SIZES) == len(IN_FEATURES) must be true,
# or len(SIZES) == 1 is true and size list SIZES[0] is used for all
# IN_FEATURES.
_C.MODEL.ANCHOR_GENERATOR.SIZES = [[32, 64, 128, 256, 512]]
# Anchor aspect ratios. For each area given in `SIZES`, anchors with different aspect
# ratios are generated by an anchor generator.
# Format: list[list[float]]. ASPECT_RATIOS[i] specifies the list of aspect ratios (H/W)
# to use for IN_FEATURES[i]; len(ASPECT_RATIOS) == len(IN_FEATURES) must be true,
# or len(ASPECT_RATIOS) == 1 is true and aspect ratio list ASPECT_RATIOS[0] is used
# for all IN_FEATURES.
_C.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS = [[0.5, 1.0, 2.0]]
# Anchor angles.
# list[list[float]], the angle in degrees, for each input feature map.
# ANGLES[i] specifies the list of angles for IN_FEATURES[i].
_C.MODEL.ANCHOR_GENERATOR.ANGLES = [[-90, 0, 90]]
# Relative offset between the center of the first anchor and the top-left corner of the image
# Value has to be in [0, 1). Recommend to use 0.5, which means half stride.
# The value is not expected to affect model accuracy.
_C.MODEL.ANCHOR_GENERATOR.OFFSET = 0.0

# ---------------------------------------------------------------------------- #
# RPN options
# ---------------------------------------------------------------------------- #
_C.MODEL.RPN = CN()
_C.MODEL.RPN.HEAD_NAME = "StandardRPNHead"  # used by RPN_HEAD_REGISTRY

# Names of the input feature maps to be used by RPN
# e.g., ["p2", "p3", "p4", "p5", "p6"] for FPN
_C.MODEL.RPN.IN_FEATURES = ["res4"]
# Remove RPN anchors that go outside the image by BOUNDARY_THRESH pixels
# Set to -1 or a large value, e.g. 100000, to disable pruning anchors
_C.MODEL.RPN.BOUNDARY_THRESH = -1
# IOU overlap ratios [BG_IOU_THRESHOLD, FG_IOU_THRESHOLD]
# Minimum overlap required between an anchor and ground-truth box for the
# (anchor, gt box) pair to be a positive example (IoU >= FG_IOU_THRESHOLD
# ==> positive RPN example: 1)
# Maximum overlap allowed between an anchor and ground-truth box for the
# (anchor, gt box) pair to be a negative example (IoU < BG_IOU_THRESHOLD
# ==> negative RPN example: 0)
# Anchors with overlap in between (BG_IOU_THRESHOLD <= IoU < FG_IOU_THRESHOLD)
# are ignored (-1)
_C.MODEL.RPN.IOU_THRESHOLDS = [0.3, 0.7]
_C.MODEL.RPN.IOU_LABELS = [0, -1, 1]
# Total number of RPN examples per image
_C.MODEL.RPN.BATCH_SIZE_PER_IMAGE = 256
# Target fraction of foreground (positive) examples per RPN minibatch
_C.MODEL.RPN.POSITIVE_FRACTION = 0.5
# Weights on (dx, dy, dw, dh) for normalizing RPN anchor regression targets
_C.MODEL.RPN.BBOX_REG_WEIGHTS = (1.0, 1.0, 1.0, 1.0)
# The transition point from L1 to L2 loss. Set to 0.0 to make the loss simply L1.
_C.MODEL.RPN.SMOOTH_L1_BETA = 0.0
_C.MODEL.RPN.LOSS_WEIGHT = 1.0
# Number of top scoring RPN proposals to keep before applying NMS
# When FPN is used, this is *per FPN level* (not total)
_C.MODEL.RPN.PRE_NMS_TOPK_TRAIN = 12000
_C.MODEL.RPN.PRE_NMS_TOPK_TEST = 6000
# Number of top scoring RPN proposals to keep after applying NMS
# When FPN is used, this limit is applied per level and then again to the union
# of proposals from all levels
# NOTE: When FPN is used, the meaning of this config is different from Detectron1.
# It means per-batch topk in Detectron1, but per-image topk here.
# See "modeling/rpn/rpn_outputs.py" for details.
_C.MODEL.RPN.POST_NMS_TOPK_TRAIN = 2000
_C.MODEL.RPN.POST_NMS_TOPK_TEST = 1000
# NMS threshold used on RPN proposals
_C.MODEL.RPN.NMS_THRESH = 0.7

# ---------------------------------------------------------------------------- #
# ROI HEADS options
# ---------------------------------------------------------------------------- #
_C.MODEL.ROI_HEADS = CN()
_C.MODEL.ROI_HEADS.NAME = "Res5ROIHeads"
# Number of foreground classes
_C.MODEL.ROI_HEADS.NUM_CLASSES = 80
# Names of the input feature maps to be used by ROI heads
# Currently all heads (box, mask, ...) use the same input feature map list
# e.g., ["p2", "p3", "p4", "p5"] is commonly used for FPN
_C.MODEL.ROI_HEADS.IN_FEATURES = ["res4"]
# IOU overlap ratios [IOU_THRESHOLD]
# Overlap threshold for an RoI to be considered background (if < IOU_THRESHOLD)
# Overlap threshold for an RoI to be considered foreground (if >= IOU_THRESHOLD)
_C.MODEL.ROI_HEADS.IOU_THRESHOLDS = [0.5]
_C.MODEL.ROI_HEADS.IOU_LABELS = [0, 1]
# RoI minibatch size *per image* (number of regions of interest [ROIs])
# Total number of RoIs per training minibatch =
#   ROI_HEADS.BATCH_SIZE_PER_IMAGE * SOLVER.IMS_PER_BATCH
# E.g., a common configuration is: 512 * 16 = 8192
_C.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
# Target fraction of RoI minibatch that is labeled foreground (i.e. class > 0)
_C.MODEL.ROI_HEADS.POSITIVE_FRACTION = 0.25

# Only used on test mode

# Minimum score threshold (assuming scores in a [0, 1] range); a value chosen to
# balance obtaining high recall with not having too many low precision
# detections that will slow down inference post processing steps (like NMS)
# A default threshold of 0.0 increases AP by ~0.2-0.3 but significantly slows down
# inference.
_C.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.05
# Overlap threshold used for non-maximum suppression (suppress boxes with
# IoU >= this threshold)
_C.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.5
# If True, augment proposals with ground-truth boxes before sampling proposals to
# train ROI heads.
_C.MODEL.ROI_HEADS.PROPOSAL_APPEND_GT = True

# ---------------------------------------------------------------------------- #
# Box Head
# ---------------------------------------------------------------------------- #
_C.MODEL.ROI_BOX_HEAD = CN()
# C4 doesn't use the head name option
# Options for non-C4 models: FastRCNNConvFCHead,
_C.MODEL.ROI_BOX_HEAD.NAME = ""
# Default weights on (dx, dy, dw, dh) for normalizing bbox regression targets
# These are empirically chosen to approximately lead to unit variance targets
_C.MODEL.ROI_BOX_HEAD.BBOX_REG_WEIGHTS = (10.0, 10.0, 5.0, 5.0)
# The transition point from L1 to L2 loss. Set to 0.0 to make the loss simply L1.
_C.MODEL.ROI_BOX_HEAD.SMOOTH_L1_BETA = 0.0
_C.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION = 14
_C.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO = 0
# Type of pooling operation applied to the incoming feature map for each RoI
_C.MODEL.ROI_BOX_HEAD.POOLER_TYPE = "ROIAlignV2"

_C.MODEL.ROI_BOX_HEAD.NUM_FC = 0
# Hidden layer dimension for FC layers in the RoI box head
_C.MODEL.ROI_BOX_HEAD.FC_DIM = 1024
_C.MODEL.ROI_BOX_HEAD.NUM_CONV = 0
# Channel dimension for Conv layers in the RoI box head
_C.MODEL.ROI_BOX_HEAD.CONV_DIM = 256
# Normalization method for the convolution layers.
# Options: "" (no norm), "GN", "SyncBN".
_C.MODEL.ROI_BOX_HEAD.NORM = ""
# Whether to use class agnostic for bbox regression
_C.MODEL.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG = False
# If true, RoI heads use bounding boxes predicted by the box head rather than proposal boxes.
_C.MODEL.ROI_BOX_HEAD.TRAIN_ON_PRED_BOXES = False

# ---------------------------------------------------------------------------- #
# Cascaded Box Head
# ---------------------------------------------------------------------------- #
_C.MODEL.ROI_BOX_CASCADE_HEAD = CN()
# The number of cascade stages is implicitly defined by the length of the following two configs.
_C.MODEL.ROI_BOX_CASCADE_HEAD.BBOX_REG_WEIGHTS = (
    (10.0, 10.0, 5.0, 5.0),
    (20.0, 20.0, 10.0, 10.0),
    (30.0, 30.0, 15.0, 15.0),
)
_C.MODEL.ROI_BOX_CASCADE_HEAD.IOUS = (0.5, 0.6, 0.7)


# ---------------------------------------------------------------------------- #
# Mask Head
# ---------------------------------------------------------------------------- #
_C.MODEL.ROI_MASK_HEAD = CN()
_C.MODEL.ROI_MASK_HEAD.NAME = "MaskRCNNConvUpsampleHead"
_C.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION = 14
_C.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO = 0
_C.MODEL.ROI_MASK_HEAD.NUM_CONV = 0  # The number of convs in the mask head
_C.MODEL.ROI_MASK_HEAD.CONV_DIM = 256
# Normalization method for the convolution layers.
# Options: "" (no norm), "GN", "SyncBN".
_C.MODEL.ROI_MASK_HEAD.NORM = ""
# Whether to use class agnostic for mask prediction
_C.MODEL.ROI_MASK_HEAD.CLS_AGNOSTIC_MASK = False
# Type of pooling operation applied to the incoming feature map for each RoI
_C.MODEL.ROI_MASK_HEAD.POOLER_TYPE = "ROIAlignV2"


# ---------------------------------------------------------------------------- #
# Keypoint Head
# ---------------------------------------------------------------------------- #
_C.MODEL.ROI_KEYPOINT_HEAD = CN()
_C.MODEL.ROI_KEYPOINT_HEAD.NAME = "KRCNNConvDeconvUpsampleHead"
_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_RESOLUTION = 14
_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_SAMPLING_RATIO = 0
_C.MODEL.ROI_KEYPOINT_HEAD.CONV_DIMS = tuple(512 for _ in range(8))
_C.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS = 17  # 17 is the number of keypoints in COCO.

# Images with too few (or no) keypoints are excluded from training.
_C.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE = 1
# Normalize by the total number of visible keypoints in the minibatch if True.
# Otherwise, normalize by the total number of keypoints that could ever exist
# in the minibatch.
# The keypoint softmax loss is only calculated on visible keypoints.
# Since the number of visible keypoints can vary significantly between
# minibatches, this has the effect of up-weighting the importance of
# minibatches with few visible keypoints. (Imagine the extreme case of
# only one visible keypoint versus N: in the case of N, each one
# contributes 1/N to the gradient compared to the single keypoint
# determining the gradient direction). Instead, we can normalize the
# loss by the total number of keypoints, if it were the case that all
# keypoints were visible in a full minibatch. (Returning to the example,
# this means that the one visible keypoint contributes as much as each
# of the N keypoints.)
_C.MODEL.ROI_KEYPOINT_HEAD.NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS = True
# Multi-task loss weight to use for keypoints
# Recommended values:
#   - use 1.0 if NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS is True
#   - use 4.0 if NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS is False
_C.MODEL.ROI_KEYPOINT_HEAD.LOSS_WEIGHT = 1.0
# Type of pooling operation applied to the incoming feature map for each RoI
_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_TYPE = "ROIAlignV2"

# ---------------------------------------------------------------------------- #
# Semantic Segmentation Head
# ---------------------------------------------------------------------------- #
_C.MODEL.SEM_SEG_HEAD = CN()
_C.MODEL.SEM_SEG_HEAD.NAME = "SemSegFPNHead"
_C.MODEL.SEM_SEG_HEAD.IN_FEATURES = ["p2", "p3", "p4", "p5"]
# Label in the semantic segmentation ground truth that is ignored, i.e., no loss is calculated for
# the corresponding pixel.
_C.MODEL.SEM_SEG_HEAD.IGNORE_VALUE = 255
# Number of classes in the semantic segmentation head
_C.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 54
# Number of channels in the 3x3 convs inside semantic-FPN heads.
_C.MODEL.SEM_SEG_HEAD.CONVS_DIM = 128
# Outputs from semantic-FPN heads are up-scaled to the COMMON_STRIDE stride.
_C.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
# Normalization method for the convolution layers. Options: "" (no norm), "GN".
|
| 389 |
+
_C.MODEL.SEM_SEG_HEAD.NORM = "GN"
|
| 390 |
+
_C.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT = 1.0
|
| 391 |
+
|
| 392 |
+
_C.MODEL.PANOPTIC_FPN = CN()
|
| 393 |
+
# Scaling of all losses from instance detection / segmentation head.
|
| 394 |
+
_C.MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT = 1.0
|
| 395 |
+
|
| 396 |
+
# options when combining instance & semantic segmentation outputs
|
| 397 |
+
_C.MODEL.PANOPTIC_FPN.COMBINE = CN({"ENABLED": True})
|
| 398 |
+
_C.MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH = 0.5
|
| 399 |
+
_C.MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT = 4096
|
| 400 |
+
_C.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = 0.5
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
# ---------------------------------------------------------------------------- #
|
| 404 |
+
# RetinaNet Head
|
| 405 |
+
# ---------------------------------------------------------------------------- #
|
| 406 |
+
_C.MODEL.RETINANET = CN()
|
| 407 |
+
|
| 408 |
+
# This is the number of foreground classes.
|
| 409 |
+
_C.MODEL.RETINANET.NUM_CLASSES = 80
|
| 410 |
+
|
| 411 |
+
_C.MODEL.RETINANET.IN_FEATURES = ["p3", "p4", "p5", "p6", "p7"]
|
| 412 |
+
|
| 413 |
+
# Convolutions to use in the cls and bbox tower
|
| 414 |
+
# NOTE: this doesn't include the last conv for logits
|
| 415 |
+
_C.MODEL.RETINANET.NUM_CONVS = 4
|
| 416 |
+
|
| 417 |
+
# IoU overlap ratio [bg, fg] for labeling anchors.
|
| 418 |
+
# Anchors with < bg are labeled negative (0)
|
| 419 |
+
# Anchors with >= bg and < fg are ignored (-1)
|
| 420 |
+
# Anchors with >= fg are labeled positive (1)
|
| 421 |
+
_C.MODEL.RETINANET.IOU_THRESHOLDS = [0.4, 0.5]
|
| 422 |
+
_C.MODEL.RETINANET.IOU_LABELS = [0, -1, 1]
|
| 423 |
+
|
| 424 |
+
# Prior prob for rare case (i.e. foreground) at the beginning of training.
|
| 425 |
+
# This is used to set the bias for the logits layer of the classifier subnet.
|
| 426 |
+
# This improves training stability in the case of heavy class imbalance.
|
| 427 |
+
_C.MODEL.RETINANET.PRIOR_PROB = 0.01
|
| 428 |
+
|
| 429 |
+
# Inference cls score threshold, only anchors with score > INFERENCE_TH are
|
| 430 |
+
# considered for inference (to improve speed)
|
| 431 |
+
_C.MODEL.RETINANET.SCORE_THRESH_TEST = 0.05
|
| 432 |
+
_C.MODEL.RETINANET.TOPK_CANDIDATES_TEST = 1000
|
| 433 |
+
_C.MODEL.RETINANET.NMS_THRESH_TEST = 0.5
|
| 434 |
+
|
| 435 |
+
# Weights on (dx, dy, dw, dh) for normalizing Retinanet anchor regression targets
|
| 436 |
+
_C.MODEL.RETINANET.BBOX_REG_WEIGHTS = (1.0, 1.0, 1.0, 1.0)
|
| 437 |
+
|
| 438 |
+
# Loss parameters
|
| 439 |
+
_C.MODEL.RETINANET.FOCAL_LOSS_GAMMA = 2.0
|
| 440 |
+
_C.MODEL.RETINANET.FOCAL_LOSS_ALPHA = 0.25
|
| 441 |
+
_C.MODEL.RETINANET.SMOOTH_L1_LOSS_BETA = 0.1
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
# ---------------------------------------------------------------------------- #
|
| 445 |
+
# ResNe[X]t options (ResNets = {ResNet, ResNeXt}
|
| 446 |
+
# Note that parts of a resnet may be used for both the backbone and the head
|
| 447 |
+
# These options apply to both
|
| 448 |
+
# ---------------------------------------------------------------------------- #
|
| 449 |
+
_C.MODEL.RESNETS = CN()
|
| 450 |
+
|
| 451 |
+
_C.MODEL.RESNETS.DEPTH = 50
|
| 452 |
+
_C.MODEL.RESNETS.OUT_FEATURES = ["res4"] # res4 for C4 backbone, res2..5 for FPN backbone
|
| 453 |
+
|
| 454 |
+
# Number of groups to use; 1 ==> ResNet; > 1 ==> ResNeXt
|
| 455 |
+
_C.MODEL.RESNETS.NUM_GROUPS = 1
|
| 456 |
+
|
| 457 |
+
# Options: FrozenBN, GN, "SyncBN", "BN"
|
| 458 |
+
_C.MODEL.RESNETS.NORM = "FrozenBN"
|
| 459 |
+
|
| 460 |
+
# Baseline width of each group.
|
| 461 |
+
# Scaling this parameters will scale the width of all bottleneck layers.
|
| 462 |
+
_C.MODEL.RESNETS.WIDTH_PER_GROUP = 64
|
| 463 |
+
|
| 464 |
+
# Place the stride 2 conv on the 1x1 filter
|
| 465 |
+
# Use True only for the original MSRA ResNet; use False for C2 and Torch models
|
| 466 |
+
_C.MODEL.RESNETS.STRIDE_IN_1X1 = True
|
| 467 |
+
|
| 468 |
+
# Apply dilation in stage "res5"
|
| 469 |
+
_C.MODEL.RESNETS.RES5_DILATION = 1
|
| 470 |
+
|
| 471 |
+
# Output width of res2. Scaling this parameters will scale the width of all 1x1 convs in ResNet
|
| 472 |
+
# For R18 and R34, this needs to be set to 64
|
| 473 |
+
_C.MODEL.RESNETS.RES2_OUT_CHANNELS = 256
|
| 474 |
+
_C.MODEL.RESNETS.STEM_OUT_CHANNELS = 64
|
| 475 |
+
|
| 476 |
+
# Apply Deformable Convolution in stages
|
| 477 |
+
# Specify if apply deform_conv on Res2, Res3, Res4, Res5
|
| 478 |
+
_C.MODEL.RESNETS.DEFORM_ON_PER_STAGE = [False, False, False, False]
|
| 479 |
+
# Use True to use modulated deform_conv (DeformableV2, https://arxiv.org/abs/1811.11168);
|
| 480 |
+
# Use False for DeformableV1.
|
| 481 |
+
_C.MODEL.RESNETS.DEFORM_MODULATED = False
|
| 482 |
+
# Number of groups in deformable conv.
|
| 483 |
+
_C.MODEL.RESNETS.DEFORM_NUM_GROUPS = 1
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
# ---------------------------------------------------------------------------- #
|
| 487 |
+
# Solver
|
| 488 |
+
# ---------------------------------------------------------------------------- #
|
| 489 |
+
_C.SOLVER = CN()
|
| 490 |
+
|
| 491 |
+
# See detectron2/solver/build.py for LR scheduler options
|
| 492 |
+
_C.SOLVER.LR_SCHEDULER_NAME = "WarmupMultiStepLR"
|
| 493 |
+
|
| 494 |
+
_C.SOLVER.MAX_ITER = 40000
|
| 495 |
+
|
| 496 |
+
_C.SOLVER.BASE_LR = 0.001
|
| 497 |
+
|
| 498 |
+
_C.SOLVER.MOMENTUM = 0.9
|
| 499 |
+
|
| 500 |
+
_C.SOLVER.NESTEROV = False
|
| 501 |
+
|
| 502 |
+
_C.SOLVER.WEIGHT_DECAY = 0.0001
|
| 503 |
+
# The weight decay that's applied to parameters of normalization layers
|
| 504 |
+
# (typically the affine transformation)
|
| 505 |
+
_C.SOLVER.WEIGHT_DECAY_NORM = 0.0
|
| 506 |
+
|
| 507 |
+
_C.SOLVER.GAMMA = 0.1
|
| 508 |
+
# The iteration number to decrease learning rate by GAMMA.
|
| 509 |
+
_C.SOLVER.STEPS = (30000,)
|
| 510 |
+
|
| 511 |
+
_C.SOLVER.WARMUP_FACTOR = 1.0 / 1000
|
| 512 |
+
_C.SOLVER.WARMUP_ITERS = 1000
|
| 513 |
+
_C.SOLVER.WARMUP_METHOD = "linear"
|
| 514 |
+
|
| 515 |
+
# Save a checkpoint after every this number of iterations
|
| 516 |
+
_C.SOLVER.CHECKPOINT_PERIOD = 5000
|
| 517 |
+
|
| 518 |
+
# Number of images per batch across all machines.
|
| 519 |
+
# If we have 16 GPUs and IMS_PER_BATCH = 32,
|
| 520 |
+
# each GPU will see 2 images per batch.
|
| 521 |
+
_C.SOLVER.IMS_PER_BATCH = 16
|
| 522 |
+
|
| 523 |
+
# Detectron v1 (and previous detection code) used a 2x higher LR and 0 WD for
|
| 524 |
+
# biases. This is not useful (at least for recent models). You should avoid
|
| 525 |
+
# changing these and they exist only to reproduce Detectron v1 training if
|
| 526 |
+
# desired.
|
| 527 |
+
_C.SOLVER.BIAS_LR_FACTOR = 1.0
|
| 528 |
+
_C.SOLVER.WEIGHT_DECAY_BIAS = _C.SOLVER.WEIGHT_DECAY
|
| 529 |
+
|
| 530 |
+
# Gradient clipping
|
| 531 |
+
_C.SOLVER.CLIP_GRADIENTS = CN({"ENABLED": False})
|
| 532 |
+
# Type of gradient clipping, currently 2 values are supported:
|
| 533 |
+
# - "value": the absolute values of elements of each gradients are clipped
|
| 534 |
+
# - "norm": the norm of the gradient for each parameter is clipped thus
|
| 535 |
+
# affecting all elements in the parameter
|
| 536 |
+
_C.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "value"
|
| 537 |
+
# Maximum absolute value used for clipping gradients
|
| 538 |
+
_C.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 1.0
|
| 539 |
+
# Floating point number p for L-p norm to be used with the "norm"
|
| 540 |
+
# gradient clipping type; for L-inf, please specify .inf
|
| 541 |
+
_C.SOLVER.CLIP_GRADIENTS.NORM_TYPE = 2.0
|
| 542 |
+
|
| 543 |
+
# ---------------------------------------------------------------------------- #
|
| 544 |
+
# Specific test options
|
| 545 |
+
# ---------------------------------------------------------------------------- #
|
| 546 |
+
_C.TEST = CN()
|
| 547 |
+
# For end-to-end tests to verify the expected accuracy.
|
| 548 |
+
# Each item is [task, metric, value, tolerance]
|
| 549 |
+
# e.g.: [['bbox', 'AP', 38.5, 0.2]]
|
| 550 |
+
_C.TEST.EXPECTED_RESULTS = []
|
| 551 |
+
# The period (in terms of steps) to evaluate the model during training.
|
| 552 |
+
# Set to 0 to disable.
|
| 553 |
+
_C.TEST.EVAL_PERIOD = 0
|
| 554 |
+
# The sigmas used to calculate keypoint OKS. See http://cocodataset.org/#keypoints-eval
|
| 555 |
+
# When empty it will use the defaults in COCO.
|
| 556 |
+
# Otherwise it should have the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS.
|
| 557 |
+
_C.TEST.KEYPOINT_OKS_SIGMAS = []
|
| 558 |
+
# Maximum number of detections to return per image during inference (100 is
|
| 559 |
+
# based on the limit established for the COCO dataset).
|
| 560 |
+
_C.TEST.DETECTIONS_PER_IMAGE = 100
|
| 561 |
+
|
| 562 |
+
_C.TEST.AUG = CN({"ENABLED": False})
|
| 563 |
+
_C.TEST.AUG.MIN_SIZES = (400, 500, 600, 700, 800, 900, 1000, 1100, 1200)
|
| 564 |
+
_C.TEST.AUG.MAX_SIZE = 4000
|
| 565 |
+
_C.TEST.AUG.FLIP = True
|
| 566 |
+
|
| 567 |
+
_C.TEST.PRECISE_BN = CN({"ENABLED": False})
|
| 568 |
+
_C.TEST.PRECISE_BN.NUM_ITER = 200
|
| 569 |
+
|
| 570 |
+
# ---------------------------------------------------------------------------- #
|
| 571 |
+
# Misc options
|
| 572 |
+
# ---------------------------------------------------------------------------- #
|
| 573 |
+
# Directory where output files are written
|
| 574 |
+
_C.OUTPUT_DIR = "./output"
|
| 575 |
+
# Set seed to negative to fully randomize everything.
|
| 576 |
+
# Set seed to positive to use a fixed seed. Note that a fixed seed increases
|
| 577 |
+
# reproducibility but does not guarantee fully deterministic behavior.
|
| 578 |
+
# Disabling all parallelism further increases reproducibility.
|
| 579 |
+
_C.SEED = -1
|
| 580 |
+
# Benchmark different cudnn algorithms.
|
| 581 |
+
# If input images have very different sizes, this option will have large overhead
|
| 582 |
+
# for about 10k iterations. It usually hurts total time, but can benefit for certain models.
|
| 583 |
+
# If input images have the same or similar sizes, benchmark is often helpful.
|
| 584 |
+
_C.CUDNN_BENCHMARK = False
|
| 585 |
+
# The period (in terms of steps) for minibatch visualization at train time.
|
| 586 |
+
# Set to 0 to disable.
|
| 587 |
+
_C.VIS_PERIOD = 0
|
| 588 |
+
|
| 589 |
+
# global config is for quick hack purposes.
|
| 590 |
+
# You can set them in command line or config files,
|
| 591 |
+
# and access it with:
|
| 592 |
+
#
|
| 593 |
+
# from detectron2.config import global_cfg
|
| 594 |
+
# print(global_cfg.HACK)
|
| 595 |
+
#
|
| 596 |
+
# Do not commit any configs into it.
|
| 597 |
+
_C.GLOBAL = CN()
|
| 598 |
+
_C.GLOBAL.HACK = 1.0
|
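The block above closes out the config schema (box/mask/keypoint heads, semantic segmentation, RetinaNet, ResNets, solver, test and misc options). For orientation only, not part of the diff: a minimal sketch of how these defaults are typically consumed, assuming the standard `get_cfg()` helper from `detectron2.config` (which returns a clone of the `_C` node defined in this file); the values below are placeholders, not recommendations.

from detectron2.config import get_cfg

cfg = get_cfg()                              # a copy of the _C defaults defined above
cfg.SOLVER.IMS_PER_BATCH = 8                 # total batch size across all GPUs
cfg.SOLVER.BASE_LR = 0.0025                  # example value only
cfg.SOLVER.STEPS = (60000, 80000)            # iterations at which LR is multiplied by SOLVER.GAMMA
cfg.TEST.EVAL_PERIOD = 5000                  # evaluate every 5000 iterations; 0 disables
print(cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE)    # "ROIAlignV2"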
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/__init__.py
ADDED
@@ -0,0 +1,18 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+from . import transforms  # isort:skip
+
+from .build import (
+    build_detection_test_loader,
+    build_detection_train_loader,
+    get_detection_dataset_dicts,
+    load_proposals_into_dataset,
+    print_instances_class_histogram,
+)
+from .catalog import DatasetCatalog, MetadataCatalog
+from .common import DatasetFromList, MapDataset
+from .dataset_mapper import DatasetMapper
+
+# ensure the builtin data are registered
+from . import datasets, samplers  # isort:skip
+
+__all__ = [k for k in globals().keys() if not k.startswith("_")]
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/build.py
ADDED
@@ -0,0 +1,397 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import bisect
+import copy
+import itertools
+import logging
+import numpy as np
+import operator
+import pickle
+import torch.utils.data
+from fvcore.common.file_io import PathManager
+from tabulate import tabulate
+from termcolor import colored
+
+from detectron2.structures import BoxMode
+from detectron2.utils.comm import get_world_size
+from detectron2.utils.env import seed_all_rng
+from detectron2.utils.logger import log_first_n
+
+from . import samplers
+from .catalog import DatasetCatalog, MetadataCatalog
+from .common import AspectRatioGroupedDataset, DatasetFromList, MapDataset
+from .dataset_mapper import DatasetMapper
+from .detection_utils import check_metadata_consistency
+
+"""
+This file contains the default logic to build a dataloader for training or testing.
+"""
+
+__all__ = [
+    "build_detection_train_loader",
+    "build_detection_test_loader",
+    "get_detection_dataset_dicts",
+    "load_proposals_into_dataset",
+    "print_instances_class_histogram",
+]
+
+
+def filter_images_with_only_crowd_annotations(dataset_dicts):
+    """
+    Filter out images with none annotations or only crowd annotations
+    (i.e., images without non-crowd annotations).
+    A common training-time preprocessing on COCO dataset.
+
+    Args:
+        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
+
+    Returns:
+        list[dict]: the same format, but filtered.
+    """
+    num_before = len(dataset_dicts)
+
+    def valid(anns):
+        for ann in anns:
+            if ann.get("iscrowd", 0) == 0:
+                return True
+        return False
+
+    dataset_dicts = [x for x in dataset_dicts if valid(x["annotations"])]
+    num_after = len(dataset_dicts)
+    logger = logging.getLogger(__name__)
+    logger.info(
+        "Removed {} images with no usable annotations. {} images left.".format(
+            num_before - num_after, num_after
+        )
+    )
+    return dataset_dicts
+
+
+def filter_images_with_few_keypoints(dataset_dicts, min_keypoints_per_image):
+    """
+    Filter out images with too few number of keypoints.
+
+    Args:
+        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
+
+    Returns:
+        list[dict]: the same format as dataset_dicts, but filtered.
+    """
+    num_before = len(dataset_dicts)
+
+    def visible_keypoints_in_image(dic):
+        # Each keypoints field has the format [x1, y1, v1, ...], where v is visibility
+        annotations = dic["annotations"]
+        return sum(
+            (np.array(ann["keypoints"][2::3]) > 0).sum()
+            for ann in annotations
+            if "keypoints" in ann
+        )
+
+    dataset_dicts = [
+        x for x in dataset_dicts if visible_keypoints_in_image(x) >= min_keypoints_per_image
+    ]
+    num_after = len(dataset_dicts)
+    logger = logging.getLogger(__name__)
+    logger.info(
+        "Removed {} images with fewer than {} keypoints.".format(
+            num_before - num_after, min_keypoints_per_image
+        )
+    )
+    return dataset_dicts
+
+
+def load_proposals_into_dataset(dataset_dicts, proposal_file):
+    """
+    Load precomputed object proposals into the dataset.
+
+    The proposal file should be a pickled dict with the following keys:
+
+    - "ids": list[int] or list[str], the image ids
+    - "boxes": list[np.ndarray], each is an Nx4 array of boxes corresponding to the image id
+    - "objectness_logits": list[np.ndarray], each is an N sized array of objectness scores
+      corresponding to the boxes.
+    - "bbox_mode": the BoxMode of the boxes array. Defaults to ``BoxMode.XYXY_ABS``.
+
+    Args:
+        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
+        proposal_file (str): file path of pre-computed proposals, in pkl format.
+
+    Returns:
+        list[dict]: the same format as dataset_dicts, but added proposal field.
+    """
+    logger = logging.getLogger(__name__)
+    logger.info("Loading proposals from: {}".format(proposal_file))
+
+    with PathManager.open(proposal_file, "rb") as f:
+        proposals = pickle.load(f, encoding="latin1")
+
+    # Rename the key names in D1 proposal files
+    rename_keys = {"indexes": "ids", "scores": "objectness_logits"}
+    for key in rename_keys:
+        if key in proposals:
+            proposals[rename_keys[key]] = proposals.pop(key)
+
+    # Fetch the indexes of all proposals that are in the dataset
+    # Convert image_id to str since they could be int.
+    img_ids = set({str(record["image_id"]) for record in dataset_dicts})
+    id_to_index = {str(id): i for i, id in enumerate(proposals["ids"]) if str(id) in img_ids}
+
+    # Assuming default bbox_mode of precomputed proposals are 'XYXY_ABS'
+    bbox_mode = BoxMode(proposals["bbox_mode"]) if "bbox_mode" in proposals else BoxMode.XYXY_ABS
+
+    for record in dataset_dicts:
+        # Get the index of the proposal
+        i = id_to_index[str(record["image_id"])]
+
+        boxes = proposals["boxes"][i]
+        objectness_logits = proposals["objectness_logits"][i]
+        # Sort the proposals in descending order of the scores
+        inds = objectness_logits.argsort()[::-1]
+        record["proposal_boxes"] = boxes[inds]
+        record["proposal_objectness_logits"] = objectness_logits[inds]
+        record["proposal_bbox_mode"] = bbox_mode
+
+    return dataset_dicts
+
+
+def _quantize(x, bin_edges):
+    bin_edges = copy.copy(bin_edges)
+    bin_edges = sorted(bin_edges)
+    quantized = list(map(lambda y: bisect.bisect_right(bin_edges, y), x))
+    return quantized
+
+
+def print_instances_class_histogram(dataset_dicts, class_names):
+    """
+    Args:
+        dataset_dicts (list[dict]): list of dataset dicts.
+        class_names (list[str]): list of class names (zero-indexed).
+    """
+    num_classes = len(class_names)
+    hist_bins = np.arange(num_classes + 1)
+    histogram = np.zeros((num_classes,), dtype=np.int)
+    for entry in dataset_dicts:
+        annos = entry["annotations"]
+        classes = [x["category_id"] for x in annos if not x.get("iscrowd", 0)]
+        histogram += np.histogram(classes, bins=hist_bins)[0]
+
+    N_COLS = min(6, len(class_names) * 2)
+
+    def short_name(x):
+        # make long class names shorter. useful for lvis
+        if len(x) > 13:
+            return x[:11] + ".."
+        return x
+
+    data = list(
+        itertools.chain(*[[short_name(class_names[i]), int(v)] for i, v in enumerate(histogram)])
+    )
+    total_num_instances = sum(data[1::2])
+    data.extend([None] * (N_COLS - (len(data) % N_COLS)))
+    if num_classes > 1:
+        data.extend(["total", total_num_instances])
+    data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)])
+    table = tabulate(
+        data,
+        headers=["category", "#instances"] * (N_COLS // 2),
+        tablefmt="pipe",
+        numalign="left",
+        stralign="center",
+    )
+    log_first_n(
+        logging.INFO,
+        "Distribution of instances among all {} categories:\n".format(num_classes)
+        + colored(table, "cyan"),
+        key="message",
+    )
+
+
+def get_detection_dataset_dicts(
+    dataset_names, filter_empty=True, min_keypoints=0, proposal_files=None
+):
+    """
+    Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.
+
+    Args:
+        dataset_names (list[str]): a list of dataset names
+        filter_empty (bool): whether to filter out images without instance annotations
+        min_keypoints (int): filter out images with fewer keypoints than
+            `min_keypoints`. Set to 0 to do nothing.
+        proposal_files (list[str]): if given, a list of object proposal files
+            that match each dataset in `dataset_names`.
+    """
+    assert len(dataset_names)
+    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
+    for dataset_name, dicts in zip(dataset_names, dataset_dicts):
+        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
+
+    if proposal_files is not None:
+        assert len(dataset_names) == len(proposal_files)
+        # load precomputed proposals from proposal files
+        dataset_dicts = [
+            load_proposals_into_dataset(dataset_i_dicts, proposal_file)
+            for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
+        ]
+
+    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))
+
+    has_instances = "annotations" in dataset_dicts[0]
+    # Keep images without instance-level GT if the dataset has semantic labels.
+    if filter_empty and has_instances and "sem_seg_file_name" not in dataset_dicts[0]:
+        dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
+
+    if min_keypoints > 0 and has_instances:
+        dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)
+
+    if has_instances:
+        try:
+            class_names = MetadataCatalog.get(dataset_names[0]).thing_classes
+            check_metadata_consistency("thing_classes", dataset_names)
+            print_instances_class_histogram(dataset_dicts, class_names)
+        except AttributeError:  # class names are not available for this dataset
+            pass
+    return dataset_dicts
+
+
+def build_detection_train_loader(cfg, mapper=None):
+    """
+    A data loader is created by the following steps:
+
+    1. Use the dataset names in config to query :class:`DatasetCatalog`, and obtain a list of dicts.
+    2. Coordinate a random shuffle order shared among all processes (all GPUs)
+    3. Each process spawn another few workers to process the dicts. Each worker will:
+       * Map each metadata dict into another format to be consumed by the model.
+       * Batch them by simply putting dicts into a list.
+
+    The batched ``list[mapped_dict]`` is what this dataloader will yield.
+
+    Args:
+        cfg (CfgNode): the config
+        mapper (callable): a callable which takes a sample (dict) from dataset and
+            returns the format to be consumed by the model.
+            By default it will be `DatasetMapper(cfg, True)`.
+
+    Returns:
+        an infinite iterator of training data
+    """
+    num_workers = get_world_size()
+    images_per_batch = cfg.SOLVER.IMS_PER_BATCH
+    assert (
+        images_per_batch % num_workers == 0
+    ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of workers ({}).".format(
+        images_per_batch, num_workers
+    )
+    assert (
+        images_per_batch >= num_workers
+    ), "SOLVER.IMS_PER_BATCH ({}) must be larger than the number of workers ({}).".format(
+        images_per_batch, num_workers
+    )
+    images_per_worker = images_per_batch // num_workers
+
+    dataset_dicts = get_detection_dataset_dicts(
+        cfg.DATASETS.TRAIN,
+        filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
+        min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
+        if cfg.MODEL.KEYPOINT_ON
+        else 0,
+        proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
+    )
+    dataset = DatasetFromList(dataset_dicts, copy=False)
+
+    if mapper is None:
+        mapper = DatasetMapper(cfg, True)
+    dataset = MapDataset(dataset, mapper)
+
+    sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
+    logger = logging.getLogger(__name__)
+    logger.info("Using training sampler {}".format(sampler_name))
+    if sampler_name == "TrainingSampler":
+        sampler = samplers.TrainingSampler(len(dataset))
+    elif sampler_name == "RepeatFactorTrainingSampler":
+        sampler = samplers.RepeatFactorTrainingSampler(
+            dataset_dicts, cfg.DATALOADER.REPEAT_THRESHOLD
+        )
+    else:
+        raise ValueError("Unknown training sampler: {}".format(sampler_name))
+
+    if cfg.DATALOADER.ASPECT_RATIO_GROUPING:
+        data_loader = torch.utils.data.DataLoader(
+            dataset,
+            sampler=sampler,
+            num_workers=cfg.DATALOADER.NUM_WORKERS,
+            batch_sampler=None,
+            collate_fn=operator.itemgetter(0),  # don't batch, but yield individual elements
+            worker_init_fn=worker_init_reset_seed,
+        )  # yield individual mapped dict
+        data_loader = AspectRatioGroupedDataset(data_loader, images_per_worker)
+    else:
+        batch_sampler = torch.utils.data.sampler.BatchSampler(
+            sampler, images_per_worker, drop_last=True
+        )
+        # drop_last so the batch always have the same size
+        data_loader = torch.utils.data.DataLoader(
+            dataset,
+            num_workers=cfg.DATALOADER.NUM_WORKERS,
+            batch_sampler=batch_sampler,
+            collate_fn=trivial_batch_collator,
+            worker_init_fn=worker_init_reset_seed,
+        )
+
+    return data_loader
+
+
+def build_detection_test_loader(cfg, dataset_name, mapper=None):
+    """
+    Similar to `build_detection_train_loader`.
+    But this function uses the given `dataset_name` argument (instead of the names in cfg),
+    and uses batch size 1.
+
+    Args:
+        cfg: a detectron2 CfgNode
+        dataset_name (str): a name of the dataset that's available in the DatasetCatalog
+        mapper (callable): a callable which takes a sample (dict) from dataset
+            and returns the format to be consumed by the model.
+            By default it will be `DatasetMapper(cfg, False)`.
+
+    Returns:
+        DataLoader: a torch DataLoader, that loads the given detection
+        dataset, with test-time transformation and batching.
+    """
+    dataset_dicts = get_detection_dataset_dicts(
+        [dataset_name],
+        filter_empty=False,
+        proposal_files=[
+            cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(dataset_name)]
+        ]
+        if cfg.MODEL.LOAD_PROPOSALS
+        else None,
+    )
+
+    dataset = DatasetFromList(dataset_dicts)
+    if mapper is None:
+        mapper = DatasetMapper(cfg, False)
+    dataset = MapDataset(dataset, mapper)
+
+    sampler = samplers.InferenceSampler(len(dataset))
+    # Always use 1 image per worker during inference since this is the
+    # standard when reporting inference time in papers.
+    batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, 1, drop_last=False)
+
+    data_loader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=cfg.DATALOADER.NUM_WORKERS,
+        batch_sampler=batch_sampler,
+        collate_fn=trivial_batch_collator,
+    )
+    return data_loader
+
+
+def trivial_batch_collator(batch):
+    """
+    A batch collator that does nothing.
+    """
+    return batch
+
+
+def worker_init_reset_seed(worker_id):
+    seed_all_rng(np.random.randint(2 ** 31) + worker_id)
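Not part of the diff: a small usage sketch of the two builders above. The dataset names below are placeholders and must already be registered in `DatasetCatalog` (the COCO splits are registered by `datasets/builtin.py` in this same commit).

from detectron2.config import get_cfg
from detectron2.data import build_detection_train_loader, build_detection_test_loader

cfg = get_cfg()
cfg.DATASETS.TRAIN = ("coco_2017_train",)   # placeholder; any registered dataset name works
cfg.DATASETS.TEST = ("coco_2017_val",)
cfg.DATALOADER.NUM_WORKERS = 2

train_loader = build_detection_train_loader(cfg)                 # infinite iterator of list[dict]
test_loader = build_detection_test_loader(cfg, "coco_2017_val")  # batch size 1, InferenceSampler

batch = next(iter(train_loader))
print(len(batch), sorted(batch[0].keys()))   # IMS_PER_BATCH // world_size items, each a mapped dict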
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/catalog.py
ADDED
@@ -0,0 +1,221 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import copy
+import logging
+import types
+from typing import List
+
+from detectron2.utils.logger import log_first_n
+
+__all__ = ["DatasetCatalog", "MetadataCatalog"]
+
+
+class DatasetCatalog(object):
+    """
+    A catalog that stores information about the data and how to obtain them.
+
+    It contains a mapping from strings
+    (which are names that identify a dataset, e.g. "coco_2014_train")
+    to a function which parses the dataset and returns the samples in the
+    format of `list[dict]`.
+
+    The returned dicts should be in Detectron2 Dataset format (See DATASETS.md for details)
+    if used with the data loader functionalities in `data/build.py,data/detection_transform.py`.
+
+    The purpose of having this catalog is to make it easy to choose
+    different data, by just using the strings in the config.
+    """
+
+    _REGISTERED = {}
+
+    @staticmethod
+    def register(name, func):
+        """
+        Args:
+            name (str): the name that identifies a dataset, e.g. "coco_2014_train".
+            func (callable): a callable which takes no arguments and returns a list of dicts.
+        """
+        assert callable(func), "You must register a function with `DatasetCatalog.register`!"
+        assert name not in DatasetCatalog._REGISTERED, "Dataset '{}' is already registered!".format(
+            name
+        )
+        DatasetCatalog._REGISTERED[name] = func
+
+    @staticmethod
+    def get(name):
+        """
+        Call the registered function and return its results.
+
+        Args:
+            name (str): the name that identifies a dataset, e.g. "coco_2014_train".
+
+        Returns:
+            list[dict]: dataset annotations.
+        """
+        try:
+            f = DatasetCatalog._REGISTERED[name]
+        except KeyError:
+            raise KeyError(
+                "Dataset '{}' is not registered! Available data are: {}".format(
+                    name, ", ".join(DatasetCatalog._REGISTERED.keys())
+                )
+            )
+        return f()
+
+    @staticmethod
+    def list() -> List[str]:
+        """
+        List all registered data.
+
+        Returns:
+            list[str]
+        """
+        return list(DatasetCatalog._REGISTERED.keys())
+
+    @staticmethod
+    def clear():
+        """
+        Remove all registered dataset.
+        """
+        DatasetCatalog._REGISTERED.clear()
+
+
+class Metadata(types.SimpleNamespace):
+    """
+    A class that supports simple attribute setter/getter.
+    It is intended for storing metadata of a dataset and make it accessible globally.
+
+    Examples:
+
+    .. code-block:: python
+
+        # somewhere when you load the data:
+        MetadataCatalog.get("mydataset").thing_classes = ["person", "dog"]
+
+        # somewhere when you print statistics or visualize:
+        classes = MetadataCatalog.get("mydataset").thing_classes
+    """
+
+    # the name of the dataset
+    # set default to N/A so that `self.name` in the errors will not trigger getattr again
+    name: str = "N/A"
+
+    _RENAMED = {
+        "class_names": "thing_classes",
+        "dataset_id_to_contiguous_id": "thing_dataset_id_to_contiguous_id",
+        "stuff_class_names": "stuff_classes",
+    }
+
+    def __getattr__(self, key):
+        if key in self._RENAMED:
+            log_first_n(
+                logging.WARNING,
+                "Metadata '{}' was renamed to '{}'!".format(key, self._RENAMED[key]),
+                n=10,
+            )
+            return getattr(self, self._RENAMED[key])
+
+        raise AttributeError(
+            "Attribute '{}' does not exist in the metadata of '{}'. Available keys are {}.".format(
+                key, self.name, str(self.__dict__.keys())
+            )
+        )
+
+    def __setattr__(self, key, val):
+        if key in self._RENAMED:
+            log_first_n(
+                logging.WARNING,
+                "Metadata '{}' was renamed to '{}'!".format(key, self._RENAMED[key]),
+                n=10,
+            )
+            setattr(self, self._RENAMED[key], val)
+
+        # Ensure that metadata of the same name stays consistent
+        try:
+            oldval = getattr(self, key)
+            assert oldval == val, (
+                "Attribute '{}' in the metadata of '{}' cannot be set "
+                "to a different value!\n{} != {}".format(key, self.name, oldval, val)
+            )
+        except AttributeError:
+            super().__setattr__(key, val)
+
+    def as_dict(self):
+        """
+        Returns all the metadata as a dict.
+        Note that modifications to the returned dict will not reflect on the Metadata object.
+        """
+        return copy.copy(self.__dict__)
+
+    def set(self, **kwargs):
+        """
+        Set multiple metadata with kwargs.
+        """
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+        return self
+
+    def get(self, key, default=None):
+        """
+        Access an attribute and return its value if exists.
+        Otherwise return default.
+        """
+        try:
+            return getattr(self, key)
+        except AttributeError:
+            return default
+
+
+class MetadataCatalog:
+    """
+    MetadataCatalog provides access to "Metadata" of a given dataset.
+
+    The metadata associated with a certain name is a singleton: once created,
+    the metadata will stay alive and will be returned by future calls to `get(name)`.
+
+    It's like global variables, so don't abuse it.
+    It's meant for storing knowledge that's constant and shared across the execution
+    of the program, e.g.: the class names in COCO.
+    """
+
+    _NAME_TO_META = {}
+
+    @staticmethod
+    def get(name):
+        """
+        Args:
+            name (str): name of a dataset (e.g. coco_2014_train).
+
+        Returns:
+            Metadata: The :class:`Metadata` instance associated with this name,
+            or create an empty one if none is available.
+        """
+        assert len(name)
+        if name in MetadataCatalog._NAME_TO_META:
+            ret = MetadataCatalog._NAME_TO_META[name]
+            # TODO this is for the BC breaking change in D15247032.
+            # Remove this in the future.
+            if hasattr(ret, "dataset_name"):
+                logger = logging.getLogger()
+                logger.warning(
+                    """
+The 'dataset_name' key in metadata is no longer used for
+sharing metadata among splits after D15247032! Add
+metadata to each split (now called dataset) separately!
+"""
+                )
+                parent_meta = MetadataCatalog.get(ret.dataset_name).as_dict()
+                ret.set(**parent_meta)
+            return ret
+        else:
+            m = MetadataCatalog._NAME_TO_META[name] = Metadata(name=name)
+            return m
+
+    @staticmethod
+    def list():
+        """
+        List all registered metadata.
+
+        Returns:
+            list[str]: keys (names of data) of all registered metadata
+        """
+        return list(MetadataCatalog._NAME_TO_META.keys())
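Not part of the diff: the intended usage of the two catalogs, mirroring the docstrings above. The dataset name and loader function below are hypothetical placeholders.

from detectron2.data import DatasetCatalog, MetadataCatalog

def get_my_dicts():
    # A registered function takes no arguments and returns list[dict] in Detectron2 Dataset format.
    return [{"file_name": "img_0.jpg", "image_id": 0, "height": 480, "width": 640, "annotations": []}]

DatasetCatalog.register("my_dataset_train", get_my_dicts)   # lazy: get_my_dicts runs only on get()
MetadataCatalog.get("my_dataset_train").set(thing_classes=["person", "dog"])

dicts = DatasetCatalog.get("my_dataset_train")
classes = MetadataCatalog.get("my_dataset_train").thing_classes   # ["person", "dog"]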
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/common.py
ADDED
@@ -0,0 +1,149 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import copy
+import logging
+import numpy as np
+import pickle
+import random
+import torch.utils.data as data
+
+from detectron2.utils.serialize import PicklableWrapper
+
+__all__ = ["MapDataset", "DatasetFromList", "AspectRatioGroupedDataset"]
+
+
+class MapDataset(data.Dataset):
+    """
+    Map a function over the elements in a dataset.
+
+    Args:
+        dataset: a dataset where map function is applied.
+        map_func: a callable which maps the element in dataset. map_func is
+            responsible for error handling, when error happens, it needs to
+            return None so the MapDataset will randomly use other
+            elements from the dataset.
+    """
+
+    def __init__(self, dataset, map_func):
+        self._dataset = dataset
+        self._map_func = PicklableWrapper(map_func)  # wrap so that a lambda will work
+
+        self._rng = random.Random(42)
+        self._fallback_candidates = set(range(len(dataset)))
+
+    def __len__(self):
+        return len(self._dataset)
+
+    def __getitem__(self, idx):
+        retry_count = 0
+        cur_idx = int(idx)
+
+        while True:
+            data = self._map_func(self._dataset[cur_idx])
+            if data is not None:
+                self._fallback_candidates.add(cur_idx)
+                return data
+
+            # _map_func fails for this idx, use a random new index from the pool
+            retry_count += 1
+            self._fallback_candidates.discard(cur_idx)
+            cur_idx = self._rng.sample(self._fallback_candidates, k=1)[0]
+
+            if retry_count >= 3:
+                logger = logging.getLogger(__name__)
+                logger.warning(
+                    "Failed to apply `_map_func` for idx: {}, retry count: {}".format(
+                        idx, retry_count
+                    )
+                )
+
+
+class DatasetFromList(data.Dataset):
+    """
+    Wrap a list to a torch Dataset. It produces elements of the list as data.
+    """
+
+    def __init__(self, lst: list, copy: bool = True, serialize: bool = True):
+        """
+        Args:
+            lst (list): a list which contains elements to produce.
+            copy (bool): whether to deepcopy the element when producing it,
+                so that the result can be modified in place without affecting the
+                source in the list.
+            serialize (bool): whether to hold memory using serialized objects, when
+                enabled, data loader workers can use shared RAM from master
+                process instead of making a copy.
+        """
+        self._lst = lst
+        self._copy = copy
+        self._serialize = serialize
+
+        def _serialize(data):
+            buffer = pickle.dumps(data, protocol=-1)
+            return np.frombuffer(buffer, dtype=np.uint8)
+
+        if self._serialize:
+            logger = logging.getLogger(__name__)
+            logger.info(
+                "Serializing {} elements to byte tensors and concatenating them all ...".format(
+                    len(self._lst)
+                )
+            )
+            self._lst = [_serialize(x) for x in self._lst]
+            self._addr = np.asarray([len(x) for x in self._lst], dtype=np.int64)
+            self._addr = np.cumsum(self._addr)
+            self._lst = np.concatenate(self._lst)
+            logger.info("Serialized dataset takes {:.2f} MiB".format(len(self._lst) / 1024 ** 2))
+
+    def __len__(self):
+        if self._serialize:
+            return len(self._addr)
+        else:
+            return len(self._lst)
+
+    def __getitem__(self, idx):
+        if self._serialize:
+            start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
+            end_addr = self._addr[idx].item()
+            bytes = memoryview(self._lst[start_addr:end_addr])
+            return pickle.loads(bytes)
+        elif self._copy:
+            return copy.deepcopy(self._lst[idx])
+        else:
+            return self._lst[idx]
+
+
+class AspectRatioGroupedDataset(data.IterableDataset):
+    """
+    Batch data that have similar aspect ratio together.
+    In this implementation, images whose aspect ratio < (or >) 1 will
+    be batched together.
+    This improves training speed because the images then need less padding
+    to form a batch.
+
+    It assumes the underlying dataset produces dicts with "width" and "height" keys.
+    It will then produce a list of original dicts with length = batch_size,
+    all with similar aspect ratios.
+    """
+
+    def __init__(self, dataset, batch_size):
+        """
+        Args:
+            dataset: an iterable. Each element must be a dict with keys
+                "width" and "height", which will be used to batch data.
+            batch_size (int):
+        """
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self._buckets = [[] for _ in range(2)]
+        # Hard-coded two aspect ratio groups: w > h and w < h.
+        # Can add support for more aspect ratio groups, but doesn't seem useful
+
+    def __iter__(self):
+        for d in self.dataset:
+            w, h = d["width"], d["height"]
+            bucket_id = 0 if w > h else 1
+            bucket = self._buckets[bucket_id]
+            bucket.append(d)
+            if len(bucket) == self.batch_size:
+                yield bucket[:]
+                del bucket[:]
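Not part of the diff: a toy sketch of how the two wrappers above compose, which is how `build.py` uses them (the dicts here are synthetic placeholders).

from detectron2.data.common import DatasetFromList, MapDataset

dicts = [{"file_name": "img_{}.jpg".format(i), "width": 640, "height": 480} for i in range(4)]
dataset = DatasetFromList(dicts, copy=False)                   # optionally serialized into one byte buffer
dataset = MapDataset(dataset, lambda d: dict(d, mapped=True))  # PicklableWrapper makes the lambda picklable
print(len(dataset), dataset[0]["mapped"])                      # 4 True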
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/dataset_mapper.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 2 |
+
import copy
|
| 3 |
+
import logging
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from fvcore.common.file_io import PathManager
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
from . import detection_utils as utils
|
| 10 |
+
from . import transforms as T
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
This file contains the default mapping that's applied to "dataset dicts".
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
__all__ = ["DatasetMapper"]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class DatasetMapper:
|
| 20 |
+
"""
|
| 21 |
+
A callable which takes a dataset dict in Detectron2 Dataset format,
|
| 22 |
+
and map it into a format used by the model.
|
| 23 |
+
|
| 24 |
+
This is the default callable to be used to map your dataset dict into training data.
|
| 25 |
+
You may need to follow it to implement your own one for customized logic,
|
| 26 |
+
such as a different way to read or transform images.
|
| 27 |
+
See :doc:`/tutorials/data_loading` for details.
|
| 28 |
+
|
| 29 |
+
The callable currently does the following:
|
| 30 |
+
|
| 31 |
+
    1. Read the image from "file_name"
    2. Applies cropping/geometric transforms to the image and annotations
    3. Prepare data and annotations to Tensor and :class:`Instances`
    """

    def __init__(self, cfg, is_train=True):
        if cfg.INPUT.CROP.ENABLED and is_train:
            self.crop_gen = T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)
            logging.getLogger(__name__).info("CropGen used in training: " + str(self.crop_gen))
        else:
            self.crop_gen = None

        self.tfm_gens = utils.build_transform_gen(cfg, is_train)

        # fmt: off
        self.img_format = cfg.INPUT.FORMAT
        self.mask_on = cfg.MODEL.MASK_ON
        self.mask_format = cfg.INPUT.MASK_FORMAT
        self.keypoint_on = cfg.MODEL.KEYPOINT_ON
        self.load_proposals = cfg.MODEL.LOAD_PROPOSALS
        # fmt: on
        if self.keypoint_on and is_train:
            # Flip only makes sense in training
            self.keypoint_hflip_indices = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
        else:
            self.keypoint_hflip_indices = None

        if self.load_proposals:
            self.min_box_side_len = cfg.MODEL.PROPOSAL_GENERATOR.MIN_SIZE
            self.proposal_topk = (
                cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
                if is_train
                else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
            )
        self.is_train = is_train

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        # USER: Write your own image loading if it's not from a file
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        if "annotations" not in dataset_dict:
            image, transforms = T.apply_transform_gens(
                ([self.crop_gen] if self.crop_gen else []) + self.tfm_gens, image
            )
        else:
            # Crop around an instance if there are instances in the image.
            # USER: Remove if you don't use cropping
            if self.crop_gen:
                crop_tfm = utils.gen_crop_transform_with_instance(
                    self.crop_gen.get_crop_size(image.shape[:2]),
                    image.shape[:2],
                    np.random.choice(dataset_dict["annotations"]),
                )
                image = crop_tfm.apply_image(image)
            image, transforms = T.apply_transform_gens(self.tfm_gens, image)
            if self.crop_gen:
                transforms = crop_tfm + transforms

        image_shape = image.shape[:2]  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))

        # USER: Remove if you don't use pre-computed proposals.
        if self.load_proposals:
            utils.transform_proposals(
                dataset_dict, image_shape, transforms, self.min_box_side_len, self.proposal_topk
            )

        if not self.is_train:
            # USER: Modify this if you want to keep them for some reason.
            dataset_dict.pop("annotations", None)
            dataset_dict.pop("sem_seg_file_name", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            # USER: Modify this if you want to keep them for some reason.
            for anno in dataset_dict["annotations"]:
                if not self.mask_on:
                    anno.pop("segmentation", None)
                if not self.keypoint_on:
                    anno.pop("keypoints", None)

            # USER: Implement additional transformations if you have other types of data
            annos = [
                utils.transform_instance_annotations(
                    obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
                )
                for obj in dataset_dict.pop("annotations")
                if obj.get("iscrowd", 0) == 0
            ]
            instances = utils.annotations_to_instances(
                annos, image_shape, mask_format=self.mask_format
            )
            # Create a tight bounding box from masks, useful when image is cropped
            if self.crop_gen and instances.has("gt_masks"):
                instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
            dataset_dict["instances"] = utils.filter_empty_instances(instances)

        # USER: Remove if you don't do semantic/panoptic segmentation.
        if "sem_seg_file_name" in dataset_dict:
            with PathManager.open(dataset_dict.pop("sem_seg_file_name"), "rb") as f:
                sem_seg_gt = Image.open(f)
                sem_seg_gt = np.asarray(sem_seg_gt, dtype="uint8")
            sem_seg_gt = transforms.apply_segmentation(sem_seg_gt)
            sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
            dataset_dict["sem_seg"] = sem_seg_gt
        return dataset_dict
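The mapper above turns one dataset dict (file name, annotations) into the tensors and Instances a model consumes. As a minimal usage sketch, assuming a working install matching this vendored detectron2 (the dataset name below is only an example), such a mapper is typically handed to the training data loader:

# Hedged sketch, not part of the added files above.
from detectron2.config import get_cfg
from detectron2.data import build_detection_train_loader
from detectron2.data.dataset_mapper import DatasetMapper

cfg = get_cfg()
cfg.DATASETS.TRAIN = ("coco_2017_train",)  # any registered dataset name
# The loader applies the mapper to every dataset dict to produce model-ready batches.
train_loader = build_detection_train_loader(cfg, mapper=DatasetMapper(cfg, is_train=True))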
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/datasets/README.md
ADDED
@@ -0,0 +1,9 @@

### Common Datasets

The datasets implemented here do not need to load the data into the final format.
They should provide only the minimal data structure needed to use the dataset, so loading can be very efficient.

For example, for an image dataset, just provide the file names and labels, but don't read the images.
Let the downstream decide how to read.
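Concretely, that "lazy" contract means a dataset function returns lightweight dicts and the dataset mapper reads pixels later. A minimal sketch, assuming hypothetical names and paths ("my_dataset_train", "images/0001.jpg"):

from detectron2.data import DatasetCatalog

def load_my_dataset():
    # Return only lightweight metadata; images are read later by the dataset mapper.
    return [
        {"file_name": "images/0001.jpg", "image_id": 1, "height": 480, "width": 640, "annotations": []},
    ]

DatasetCatalog.register("my_dataset_train", load_my_dataset)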
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/datasets/__init__.py
ADDED
@@ -0,0 +1,9 @@

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .cityscapes import load_cityscapes_instances
from .coco import load_coco_json, load_sem_seg
from .lvis import load_lvis_json, register_lvis_instances, get_lvis_instances_meta
from .register_coco import register_coco_instances, register_coco_panoptic_separated
from . import builtin  # ensure the builtin data are registered


__all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/datasets/builtin.py
ADDED
@@ -0,0 +1,220 @@

# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


"""
This file registers pre-defined data at hard-coded paths, and their metadata.

We hard-code metadata for common data. This will enable:
1. Consistency check when loading the data
2. Use models on these standard data directly and run demos,
   without having to download the dataset annotations

We hard-code some paths to the dataset that's assumed to
exist in "./data/".

Users SHOULD NOT use this file to create new dataset / metadata for new dataset.
To add new dataset, refer to the tutorial "docs/DATASETS.md".
"""

import os

from detectron2.data import DatasetCatalog, MetadataCatalog

from .builtin_meta import _get_builtin_metadata
from .cityscapes import load_cityscapes_instances, load_cityscapes_semantic
from .lvis import get_lvis_instances_meta, register_lvis_instances
from .pascal_voc import register_pascal_voc
from .register_coco import register_coco_instances, register_coco_panoptic_separated

# ==== Predefined data and splits for COCO ==========

_PREDEFINED_SPLITS_COCO = {}
_PREDEFINED_SPLITS_COCO["coco"] = {
    "coco_2014_train": ("coco/train2014", "coco/annotations/instances_train2014.json"),
    "coco_2014_val": ("coco/val2014", "coco/annotations/instances_val2014.json"),
    "coco_2014_minival": ("coco/val2014", "coco/annotations/instances_minival2014.json"),
    "coco_2014_minival_100": ("coco/val2014", "coco/annotations/instances_minival2014_100.json"),
    "coco_2014_valminusminival": (
        "coco/val2014",
        "coco/annotations/instances_valminusminival2014.json",
    ),
    "coco_2017_train": ("coco/train2017", "coco/annotations/instances_train2017.json"),
    "coco_2017_val": ("coco/val2017", "coco/annotations/instances_val2017.json"),
    "coco_2017_test": ("coco/test2017", "coco/annotations/image_info_test2017.json"),
    "coco_2017_test-dev": ("coco/test2017", "coco/annotations/image_info_test-dev2017.json"),
    "coco_2017_val_100": ("coco/val2017", "coco/annotations/instances_val2017_100.json"),
}

_PREDEFINED_SPLITS_COCO["coco_person"] = {
    "keypoints_coco_2014_train": (
        "coco/train2014",
        "coco/annotations/person_keypoints_train2014.json",
    ),
    "keypoints_coco_2014_val": ("coco/val2014", "coco/annotations/person_keypoints_val2014.json"),
    "keypoints_coco_2014_minival": (
        "coco/val2014",
        "coco/annotations/person_keypoints_minival2014.json",
    ),
    "keypoints_coco_2014_valminusminival": (
        "coco/val2014",
        "coco/annotations/person_keypoints_valminusminival2014.json",
    ),
    "keypoints_coco_2014_minival_100": (
        "coco/val2014",
        "coco/annotations/person_keypoints_minival2014_100.json",
    ),
    "keypoints_coco_2017_train": (
        "coco/train2017",
        "coco/annotations/person_keypoints_train2017.json",
    ),
    "keypoints_coco_2017_val": ("coco/val2017", "coco/annotations/person_keypoints_val2017.json"),
    "keypoints_coco_2017_val_100": (
        "coco/val2017",
        "coco/annotations/person_keypoints_val2017_100.json",
    ),
}


_PREDEFINED_SPLITS_COCO_PANOPTIC = {
    "coco_2017_train_panoptic": (
        # This is the original panoptic annotation directory
        "coco/panoptic_train2017",
        "coco/annotations/panoptic_train2017.json",
        # This directory contains semantic annotations that are
        # converted from panoptic annotations.
        # It is used by PanopticFPN.
        # You can use the script at detectron2/data/prepare_panoptic_fpn.py
        # to create these directories.
        "coco/panoptic_stuff_train2017",
    ),
    "coco_2017_val_panoptic": (
        "coco/panoptic_val2017",
        "coco/annotations/panoptic_val2017.json",
        "coco/panoptic_stuff_val2017",
    ),
    "coco_2017_val_100_panoptic": (
        "coco/panoptic_val2017_100",
        "coco/annotations/panoptic_val2017_100.json",
        "coco/panoptic_stuff_val2017_100",
    ),
}


def register_all_coco(root):
    for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_COCO.items():
        for key, (image_root, json_file) in splits_per_dataset.items():
            # Assume pre-defined data live in `./data`.
            register_coco_instances(
                key,
                _get_builtin_metadata(dataset_name),
                os.path.join(root, json_file) if "://" not in json_file else json_file,
                os.path.join(root, image_root),
            )

    for (
        prefix,
        (panoptic_root, panoptic_json, semantic_root),
    ) in _PREDEFINED_SPLITS_COCO_PANOPTIC.items():
        prefix_instances = prefix[: -len("_panoptic")]
        instances_meta = MetadataCatalog.get(prefix_instances)
        image_root, instances_json = instances_meta.image_root, instances_meta.json_file
        register_coco_panoptic_separated(
            prefix,
            _get_builtin_metadata("coco_panoptic_separated"),
            image_root,
            os.path.join(root, panoptic_root),
            os.path.join(root, panoptic_json),
            os.path.join(root, semantic_root),
            instances_json,
        )


# ==== Predefined data and splits for LVIS ==========


_PREDEFINED_SPLITS_LVIS = {
    "lvis_v0.5": {
        "lvis_v0.5_train": ("coco/train2017", "lvis/lvis_v0.5_train.json"),
        "lvis_v0.5_val": ("coco/val2017", "lvis/lvis_v0.5_val.json"),
        "lvis_v0.5_val_rand_100": ("coco/val2017", "lvis/lvis_v0.5_val_rand_100.json"),
        "lvis_v0.5_test": ("coco/test2017", "lvis/lvis_v0.5_image_info_test.json"),
    },
    "lvis_v0.5_cocofied": {
        "lvis_v0.5_train_cocofied": ("coco/train2017", "lvis/lvis_v0.5_train_cocofied.json"),
        "lvis_v0.5_val_cocofied": ("coco/val2017", "lvis/lvis_v0.5_val_cocofied.json"),
    },
}


def register_all_lvis(root):
    for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_LVIS.items():
        for key, (image_root, json_file) in splits_per_dataset.items():
            # Assume pre-defined data live in `./data`.
            register_lvis_instances(
                key,
                get_lvis_instances_meta(dataset_name),
                os.path.join(root, json_file) if "://" not in json_file else json_file,
                os.path.join(root, image_root),
            )


# ==== Predefined splits for raw cityscapes images ===========


_RAW_CITYSCAPES_SPLITS = {
    "cityscapes_fine_{task}_train": ("cityscapes/leftImg8bit/train", "cityscapes/gtFine/train"),
    "cityscapes_fine_{task}_val": ("cityscapes/leftImg8bit/val", "cityscapes/gtFine/val"),
    "cityscapes_fine_{task}_test": ("cityscapes/leftImg8bit/test", "cityscapes/gtFine/test"),
}


def register_all_cityscapes(root):
    for key, (image_dir, gt_dir) in _RAW_CITYSCAPES_SPLITS.items():
        meta = _get_builtin_metadata("cityscapes")
        image_dir = os.path.join(root, image_dir)
        gt_dir = os.path.join(root, gt_dir)

        inst_key = key.format(task="instance_seg")
        DatasetCatalog.register(
            inst_key,
            lambda x=image_dir, y=gt_dir: load_cityscapes_instances(
                x, y, from_json=True, to_polygons=True
            ),
        )
        MetadataCatalog.get(inst_key).set(
            image_dir=image_dir, gt_dir=gt_dir, evaluator_type="cityscapes_instance", **meta
        )

        sem_key = key.format(task="sem_seg")
        DatasetCatalog.register(
            sem_key, lambda x=image_dir, y=gt_dir: load_cityscapes_semantic(x, y)
        )
        MetadataCatalog.get(sem_key).set(
            image_dir=image_dir, gt_dir=gt_dir, evaluator_type="cityscapes_sem_seg", **meta
        )


# ==== Predefined splits for PASCAL VOC ===========
def register_all_pascal_voc(root):
    SPLITS = [
        ("voc_2007_trainval", "VOC2007", "trainval"),
        ("voc_2007_train", "VOC2007", "train"),
        ("voc_2007_val", "VOC2007", "val"),
        ("voc_2007_test", "VOC2007", "test"),
        ("voc_2012_trainval", "VOC2012", "trainval"),
        ("voc_2012_train", "VOC2012", "train"),
        ("voc_2012_val", "VOC2012", "val"),
    ]
    for name, dirname, split in SPLITS:
        year = 2007 if "2007" in name else 2012
        register_pascal_voc(name, os.path.join(root, dirname), split, year)
        MetadataCatalog.get(name).evaluator_type = "pascal_voc"


# Register them all under "./data"
_root = os.getenv("DETECTRON2_DATASETS", "data")
register_all_coco(_root)
register_all_lvis(_root)
register_all_cityscapes(_root)
register_all_pascal_voc(_root)
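The builtin registrations above run at import time and read the dataset root from the DETECTRON2_DATASETS environment variable (defaulting to "./data"), so pointing them elsewhere means setting that variable before detectron2.data is imported. For a custom COCO-format split, the docstring's advice is to register it yourself rather than edit this file; a hedged sketch with hypothetical names and paths:

from detectron2.data.datasets import register_coco_instances

# "my_coco_train" and the paths below are placeholders, not real assets.
register_coco_instances(
    "my_coco_train",
    {},  # extra metadata; thing_classes are filled in from the json when loaded
    "/datasets/my_coco/annotations/train.json",
    "/datasets/my_coco/images",
)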
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/datasets/builtin_meta.py
ADDED
@@ -0,0 +1,267 @@

# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


# All coco categories, together with their nice-looking visualization colors
# It's from https://github.com/cocodataset/panopticapi/blob/master/panoptic_coco_categories.json
COCO_CATEGORIES = [
    {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
    {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"},
    {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"},
    {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"},
    {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"},
    {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"},
    {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"},
    {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"},
    {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"},
    {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"},
    {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"},
    {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"},
    {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"},
    {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"},
    {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"},
    {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"},
    {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"},
    {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"},
    {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"},
    {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"},
    {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"},
    {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"},
    {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"},
    {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"},
    {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"},
    {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"},
    {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"},
    {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"},
    {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"},
    {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"},
    {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"},
    {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"},
    {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"},
    {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"},
    {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"},
    {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"},
    {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"},
    {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"},
    {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"},
    {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"},
    {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"},
    {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"},
    {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"},
    {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"},
    {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"},
    {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"},
    {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"},
    {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"},
    {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"},
    {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"},
    {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"},
    {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"},
    {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"},
    {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"},
    {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"},
    {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"},
    {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"},
    {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"},
    {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"},
    {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"},
    {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"},
    {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"},
    {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"},
    {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"},
    {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"},
    {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"},
    {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"},
    {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"},
    {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"},
    {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"},
    {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"},
    {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"},
    {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"},
    {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"},
    {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"},
    {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"},
    {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"},
    {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"},
    {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"},
    {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"},
    {"color": [255, 255, 128], "isthing": 0, "id": 92, "name": "banner"},
    {"color": [147, 211, 203], "isthing": 0, "id": 93, "name": "blanket"},
    {"color": [150, 100, 100], "isthing": 0, "id": 95, "name": "bridge"},
    {"color": [168, 171, 172], "isthing": 0, "id": 100, "name": "cardboard"},
    {"color": [146, 112, 198], "isthing": 0, "id": 107, "name": "counter"},
    {"color": [210, 170, 100], "isthing": 0, "id": 109, "name": "curtain"},
    {"color": [92, 136, 89], "isthing": 0, "id": 112, "name": "door-stuff"},
    {"color": [218, 88, 184], "isthing": 0, "id": 118, "name": "floor-wood"},
    {"color": [241, 129, 0], "isthing": 0, "id": 119, "name": "flower"},
    {"color": [217, 17, 255], "isthing": 0, "id": 122, "name": "fruit"},
    {"color": [124, 74, 181], "isthing": 0, "id": 125, "name": "gravel"},
    {"color": [70, 70, 70], "isthing": 0, "id": 128, "name": "house"},
    {"color": [255, 228, 255], "isthing": 0, "id": 130, "name": "light"},
    {"color": [154, 208, 0], "isthing": 0, "id": 133, "name": "mirror-stuff"},
    {"color": [193, 0, 92], "isthing": 0, "id": 138, "name": "net"},
    {"color": [76, 91, 113], "isthing": 0, "id": 141, "name": "pillow"},
    {"color": [255, 180, 195], "isthing": 0, "id": 144, "name": "platform"},
    {"color": [106, 154, 176], "isthing": 0, "id": 145, "name": "playingfield"},
    {"color": [230, 150, 140], "isthing": 0, "id": 147, "name": "railroad"},
    {"color": [60, 143, 255], "isthing": 0, "id": 148, "name": "river"},
    {"color": [128, 64, 128], "isthing": 0, "id": 149, "name": "road"},
    {"color": [92, 82, 55], "isthing": 0, "id": 151, "name": "roof"},
    {"color": [254, 212, 124], "isthing": 0, "id": 154, "name": "sand"},
    {"color": [73, 77, 174], "isthing": 0, "id": 155, "name": "sea"},
    {"color": [255, 160, 98], "isthing": 0, "id": 156, "name": "shelf"},
    {"color": [255, 255, 255], "isthing": 0, "id": 159, "name": "snow"},
    {"color": [104, 84, 109], "isthing": 0, "id": 161, "name": "stairs"},
    {"color": [169, 164, 131], "isthing": 0, "id": 166, "name": "tent"},
    {"color": [225, 199, 255], "isthing": 0, "id": 168, "name": "towel"},
    {"color": [137, 54, 74], "isthing": 0, "id": 171, "name": "wall-brick"},
    {"color": [135, 158, 223], "isthing": 0, "id": 175, "name": "wall-stone"},
    {"color": [7, 246, 231], "isthing": 0, "id": 176, "name": "wall-tile"},
    {"color": [107, 255, 200], "isthing": 0, "id": 177, "name": "wall-wood"},
    {"color": [58, 41, 149], "isthing": 0, "id": 178, "name": "water-other"},
    {"color": [183, 121, 142], "isthing": 0, "id": 180, "name": "window-blind"},
    {"color": [255, 73, 97], "isthing": 0, "id": 181, "name": "window-other"},
    {"color": [107, 142, 35], "isthing": 0, "id": 184, "name": "tree-merged"},
    {"color": [190, 153, 153], "isthing": 0, "id": 185, "name": "fence-merged"},
    {"color": [146, 139, 141], "isthing": 0, "id": 186, "name": "ceiling-merged"},
    {"color": [70, 130, 180], "isthing": 0, "id": 187, "name": "sky-other-merged"},
    {"color": [134, 199, 156], "isthing": 0, "id": 188, "name": "cabinet-merged"},
    {"color": [209, 226, 140], "isthing": 0, "id": 189, "name": "table-merged"},
    {"color": [96, 36, 108], "isthing": 0, "id": 190, "name": "floor-other-merged"},
    {"color": [96, 96, 96], "isthing": 0, "id": 191, "name": "pavement-merged"},
    {"color": [64, 170, 64], "isthing": 0, "id": 192, "name": "mountain-merged"},
    {"color": [152, 251, 152], "isthing": 0, "id": 193, "name": "grass-merged"},
    {"color": [208, 229, 228], "isthing": 0, "id": 194, "name": "dirt-merged"},
    {"color": [206, 186, 171], "isthing": 0, "id": 195, "name": "paper-merged"},
    {"color": [152, 161, 64], "isthing": 0, "id": 196, "name": "food-other-merged"},
    {"color": [116, 112, 0], "isthing": 0, "id": 197, "name": "building-other-merged"},
    {"color": [0, 114, 143], "isthing": 0, "id": 198, "name": "rock-merged"},
    {"color": [102, 102, 156], "isthing": 0, "id": 199, "name": "wall-other-merged"},
    {"color": [250, 141, 255], "isthing": 0, "id": 200, "name": "rug-merged"},
]

# fmt: off
COCO_PERSON_KEYPOINT_NAMES = (
    "nose",
    "left_eye", "right_eye",
    "left_ear", "right_ear",
    "left_shoulder", "right_shoulder",
    "left_elbow", "right_elbow",
    "left_wrist", "right_wrist",
    "left_hip", "right_hip",
    "left_knee", "right_knee",
    "left_ankle", "right_ankle",
)
# fmt: on

# Pairs of keypoints that should be exchanged under horizontal flipping
COCO_PERSON_KEYPOINT_FLIP_MAP = (
    ("left_eye", "right_eye"),
    ("left_ear", "right_ear"),
    ("left_shoulder", "right_shoulder"),
    ("left_elbow", "right_elbow"),
    ("left_wrist", "right_wrist"),
    ("left_hip", "right_hip"),
    ("left_knee", "right_knee"),
    ("left_ankle", "right_ankle"),
)

# rules for pairs of keypoints to draw a line between, and the line color to use.
KEYPOINT_CONNECTION_RULES = [
    # face
    ("left_ear", "left_eye", (102, 204, 255)),
    ("right_ear", "right_eye", (51, 153, 255)),
    ("left_eye", "nose", (102, 0, 204)),
    ("nose", "right_eye", (51, 102, 255)),
    # upper-body
    ("left_shoulder", "right_shoulder", (255, 128, 0)),
    ("left_shoulder", "left_elbow", (153, 255, 204)),
    ("right_shoulder", "right_elbow", (128, 229, 255)),
    ("left_elbow", "left_wrist", (153, 255, 153)),
    ("right_elbow", "right_wrist", (102, 255, 224)),
    # lower-body
    ("left_hip", "right_hip", (255, 102, 0)),
    ("left_hip", "left_knee", (255, 255, 77)),
    ("right_hip", "right_knee", (153, 255, 204)),
    ("left_knee", "left_ankle", (191, 255, 128)),
    ("right_knee", "right_ankle", (255, 195, 77)),
]


def _get_coco_instances_meta():
    thing_ids = [k["id"] for k in COCO_CATEGORIES if k["isthing"] == 1]
    thing_colors = [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 1]
    assert len(thing_ids) == 80, len(thing_ids)
    # Mapping from the incontiguous COCO category id to an id in [0, 79]
    thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)}
    thing_classes = [k["name"] for k in COCO_CATEGORIES if k["isthing"] == 1]
    ret = {
        "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
        "thing_classes": thing_classes,
        "thing_colors": thing_colors,
    }
    return ret


def _get_coco_panoptic_separated_meta():
    """
    Returns metadata for "separated" version of the panoptic segmentation dataset.
    """
    stuff_ids = [k["id"] for k in COCO_CATEGORIES if k["isthing"] == 0]
    assert len(stuff_ids) == 53, len(stuff_ids)

    # For semantic segmentation, this mapping maps from contiguous stuff id
    # (in [0, 53], used in models) to ids in the dataset (used for processing results)
    # The id 0 is mapped to an extra category "thing".
    stuff_dataset_id_to_contiguous_id = {k: i + 1 for i, k in enumerate(stuff_ids)}
    # When converting COCO panoptic annotations to semantic annotations
    # We label the "thing" category to 0
    stuff_dataset_id_to_contiguous_id[0] = 0

    # 54 names for COCO stuff categories (including "things")
    stuff_classes = ["things"] + [
        k["name"].replace("-other", "").replace("-merged", "")
        for k in COCO_CATEGORIES
        if k["isthing"] == 0
    ]

    # NOTE: I randomly picked a color for things
    stuff_colors = [[82, 18, 128]] + [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 0]
    ret = {
        "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
        "stuff_classes": stuff_classes,
        "stuff_colors": stuff_colors,
    }
    ret.update(_get_coco_instances_meta())
    return ret


def _get_builtin_metadata(dataset_name):
    if dataset_name == "coco":
        return _get_coco_instances_meta()
    if dataset_name == "coco_panoptic_separated":
        return _get_coco_panoptic_separated_meta()
    elif dataset_name == "coco_person":
        return {
            "thing_classes": ["person"],
            "keypoint_names": COCO_PERSON_KEYPOINT_NAMES,
            "keypoint_flip_map": COCO_PERSON_KEYPOINT_FLIP_MAP,
            "keypoint_connection_rules": KEYPOINT_CONNECTION_RULES,
        }
    elif dataset_name == "cityscapes":
        # fmt: off
        CITYSCAPES_THING_CLASSES = [
            "person", "rider", "car", "truck",
            "bus", "train", "motorcycle", "bicycle",
        ]
        CITYSCAPES_STUFF_CLASSES = [
            "road", "sidewalk", "building", "wall", "fence", "pole", "traffic light",
            "traffic sign", "vegetation", "terrain", "sky", "person", "rider", "car",
            "truck", "bus", "train", "motorcycle", "bicycle", "license plate",
        ]
        # fmt: on
        return {
            "thing_classes": CITYSCAPES_THING_CLASSES,
            "stuff_classes": CITYSCAPES_STUFF_CLASSES,
        }
    raise KeyError("No built-in metadata for dataset {}".format(dataset_name))
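The key subtlety in _get_coco_instances_meta above is the contiguous-id mapping: COCO's 80 thing ids have gaps (for example, there is no category 12), so models are trained on remapped ids in [0, 79]. A small worked example of the same dict comprehension, using a toy id list rather than the real 80-entry one:

# Illustration only; toy_ids is a hypothetical subset with a gap, mirroring COCO's skipped ids.
toy_ids = [1, 2, 3, 11, 13]  # note the gap: 12 is absent
mapping = {k: i for i, k in enumerate(toy_ids)}
assert mapping == {1: 0, 2: 1, 3: 2, 11: 3, 13: 4}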
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/datasets/cityscapes.py
ADDED
@@ -0,0 +1,329 @@

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import functools
import json
import logging
import multiprocessing as mp
import numpy as np
import os
from itertools import chain
import pycocotools.mask as mask_util
from fvcore.common.file_io import PathManager
from PIL import Image

from detectron2.structures import BoxMode
from detectron2.utils.comm import get_world_size
from detectron2.utils.logger import setup_logger

try:
    import cv2  # noqa
except ImportError:
    # OpenCV is an optional dependency at the moment
    pass


logger = logging.getLogger(__name__)


def get_cityscapes_files(image_dir, gt_dir):
    files = []
    # scan through the directory
    cities = PathManager.ls(image_dir)
    logger.info(f"{len(cities)} cities found in '{image_dir}'.")
    for city in cities:
        city_img_dir = os.path.join(image_dir, city)
        city_gt_dir = os.path.join(gt_dir, city)
        for basename in PathManager.ls(city_img_dir):
            image_file = os.path.join(city_img_dir, basename)

            suffix = "leftImg8bit.png"
            assert basename.endswith(suffix)
            basename = basename[: -len(suffix)]

            instance_file = os.path.join(city_gt_dir, basename + "gtFine_instanceIds.png")
            label_file = os.path.join(city_gt_dir, basename + "gtFine_labelIds.png")
            json_file = os.path.join(city_gt_dir, basename + "gtFine_polygons.json")

            files.append((image_file, instance_file, label_file, json_file))
    assert len(files), "No images found in {}".format(image_dir)
    for f in files[0]:
        assert PathManager.isfile(f), f
    return files


def load_cityscapes_instances(image_dir, gt_dir, from_json=True, to_polygons=True):
    """
    Args:
        image_dir (str): path to the raw dataset. e.g., "~/cityscapes/leftImg8bit/train".
        gt_dir (str): path to the raw annotations. e.g., "~/cityscapes/gtFine/train".
        from_json (bool): whether to read annotations from the raw json file or the png files.
        to_polygons (bool): whether to represent the segmentation as polygons
            (COCO's format) instead of masks (cityscapes's format).

    Returns:
        list[dict]: a list of dicts in Detectron2 standard format. (See
        `Using Custom Datasets </tutorials/data.html>`_ )
    """
    if from_json:
        assert to_polygons, (
            "Cityscapes's json annotations are in polygon format. "
            "Converting to mask format is not supported now."
        )
    files = get_cityscapes_files(image_dir, gt_dir)

    logger.info("Preprocessing cityscapes annotations ...")
    # This is still not fast: all workers will execute duplicate works and will
    # take up to 10m on a 8GPU server.
    pool = mp.Pool(processes=max(mp.cpu_count() // get_world_size() // 2, 4))

    ret = pool.map(
        functools.partial(cityscapes_files_to_dict, from_json=from_json, to_polygons=to_polygons),
        files,
    )
    logger.info("Loaded {} images from {}".format(len(ret), image_dir))

    # Map cityscape ids to contiguous ids
    from cityscapesscripts.helpers.labels import labels

    labels = [l for l in labels if l.hasInstances and not l.ignoreInEval]
    dataset_id_to_contiguous_id = {l.id: idx for idx, l in enumerate(labels)}
    for dict_per_image in ret:
        for anno in dict_per_image["annotations"]:
            anno["category_id"] = dataset_id_to_contiguous_id[anno["category_id"]]
    return ret


def load_cityscapes_semantic(image_dir, gt_dir):
    """
    Args:
        image_dir (str): path to the raw dataset. e.g., "~/cityscapes/leftImg8bit/train".
        gt_dir (str): path to the raw annotations. e.g., "~/cityscapes/gtFine/train".

    Returns:
        list[dict]: a list of dict, each has "file_name" and
            "sem_seg_file_name".
    """
    ret = []
    # gt_dir is small and contain many small files. make sense to fetch to local first
    gt_dir = PathManager.get_local_path(gt_dir)
    for image_file, _, label_file, json_file in get_cityscapes_files(image_dir, gt_dir):
        label_file = label_file.replace("labelIds", "labelTrainIds")

        with PathManager.open(json_file, "r") as f:
            jsonobj = json.load(f)
        ret.append(
            {
                "file_name": image_file,
                "sem_seg_file_name": label_file,
                "height": jsonobj["imgHeight"],
                "width": jsonobj["imgWidth"],
            }
        )
    assert len(ret), f"No images found in {image_dir}!"
    assert PathManager.isfile(
        ret[0]["sem_seg_file_name"]
    ), "Please generate labelTrainIds.png with cityscapesscripts/preparation/createTrainIdLabelImgs.py"  # noqa
    return ret


def cityscapes_files_to_dict(files, from_json, to_polygons):
    """
    Parse cityscapes annotation files to a instance segmentation dataset dict.

    Args:
        files (tuple): consists of (image_file, instance_id_file, label_id_file, json_file)
        from_json (bool): whether to read annotations from the raw json file or the png files.
        to_polygons (bool): whether to represent the segmentation as polygons
            (COCO's format) instead of masks (cityscapes's format).

    Returns:
        A dict in Detectron2 Dataset format.
    """
    from cityscapesscripts.helpers.labels import id2label, name2label

    image_file, instance_id_file, _, json_file = files

    annos = []

    if from_json:
        from shapely.geometry import MultiPolygon, Polygon

        with PathManager.open(json_file, "r") as f:
            jsonobj = json.load(f)
        ret = {
            "file_name": image_file,
            "image_id": os.path.basename(image_file),
            "height": jsonobj["imgHeight"],
            "width": jsonobj["imgWidth"],
        }

        # `polygons_union` contains the union of all valid polygons.
        polygons_union = Polygon()

        # CityscapesScripts draw the polygons in sequential order
        # and each polygon *overwrites* existing ones. See
        # (https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/preparation/json2instanceImg.py)  # noqa
        # We use reverse order, and each polygon *avoids* early ones.
        # This will resolve the ploygon overlaps in the same way as CityscapesScripts.
        for obj in jsonobj["objects"][::-1]:
            if "deleted" in obj:  # cityscapes data format specific
                continue
            label_name = obj["label"]

            try:
                label = name2label[label_name]
            except KeyError:
                if label_name.endswith("group"):  # crowd area
                    label = name2label[label_name[: -len("group")]]
                else:
                    raise
            if label.id < 0:  # cityscapes data format
                continue

            # Cityscapes's raw annotations uses integer coordinates
            # Therefore +0.5 here
            poly_coord = np.asarray(obj["polygon"], dtype="f4") + 0.5
            # CityscapesScript uses PIL.ImageDraw.polygon to rasterize
            # polygons for evaluation. This function operates in integer space
            # and draws each pixel whose center falls into the polygon.
            # Therefore it draws a polygon which is 0.5 "fatter" in expectation.
            # We therefore dilate the input polygon by 0.5 as our input.
            poly = Polygon(poly_coord).buffer(0.5, resolution=4)

            if not label.hasInstances or label.ignoreInEval:
                # even if we won't store the polygon it still contributes to overlaps resolution
                polygons_union = polygons_union.union(poly)
                continue

            # Take non-overlapping part of the polygon
            poly_wo_overlaps = poly.difference(polygons_union)
            if poly_wo_overlaps.is_empty:
                continue
            polygons_union = polygons_union.union(poly)

            anno = {}
            anno["iscrowd"] = label_name.endswith("group")
            anno["category_id"] = label.id

            if isinstance(poly_wo_overlaps, Polygon):
                poly_list = [poly_wo_overlaps]
            elif isinstance(poly_wo_overlaps, MultiPolygon):
                poly_list = poly_wo_overlaps.geoms
            else:
                raise NotImplementedError("Unknown geometric structure {}".format(poly_wo_overlaps))

            poly_coord = []
            for poly_el in poly_list:
                # COCO API can work only with exterior boundaries now, hence we store only them.
                # TODO: store both exterior and interior boundaries once other parts of the
                # codebase support holes in polygons.
                poly_coord.append(list(chain(*poly_el.exterior.coords)))
            anno["segmentation"] = poly_coord
            (xmin, ymin, xmax, ymax) = poly_wo_overlaps.bounds

            anno["bbox"] = (xmin, ymin, xmax, ymax)
            anno["bbox_mode"] = BoxMode.XYXY_ABS

            annos.append(anno)
    else:
        # See also the official annotation parsing scripts at
        # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/instances2dict.py  # noqa
        with PathManager.open(instance_id_file, "rb") as f:
            inst_image = np.asarray(Image.open(f), order="F")
        # ids < 24 are stuff labels (filtering them first is about 5% faster)
        flattened_ids = np.unique(inst_image[inst_image >= 24])

        ret = {
            "file_name": image_file,
            "image_id": os.path.basename(image_file),
            "height": inst_image.shape[0],
            "width": inst_image.shape[1],
        }

        for instance_id in flattened_ids:
            # For non-crowd annotations, instance_id // 1000 is the label_id
            # Crowd annotations have <1000 instance ids
            label_id = instance_id // 1000 if instance_id >= 1000 else instance_id
            label = id2label[label_id]
            if not label.hasInstances or label.ignoreInEval:
                continue

            anno = {}
            anno["iscrowd"] = instance_id < 1000
            anno["category_id"] = label.id

            mask = np.asarray(inst_image == instance_id, dtype=np.uint8, order="F")

            inds = np.nonzero(mask)
            ymin, ymax = inds[0].min(), inds[0].max()
            xmin, xmax = inds[1].min(), inds[1].max()
            anno["bbox"] = (xmin, ymin, xmax, ymax)
            if xmax <= xmin or ymax <= ymin:
                continue
            anno["bbox_mode"] = BoxMode.XYXY_ABS
            if to_polygons:
                # This conversion comes from D4809743 and D5171122,
                # when Mask-RCNN was first developed.
                contours = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[
                    -2
                ]
                polygons = [c.reshape(-1).tolist() for c in contours if len(c) >= 3]
                # opencv's can produce invalid polygons
                if len(polygons) == 0:
                    continue
                anno["segmentation"] = polygons
            else:
                anno["segmentation"] = mask_util.encode(mask[:, :, None])[0]
            annos.append(anno)
    ret["annotations"] = annos
    return ret


if __name__ == "__main__":
    """
    Test the cityscapes dataset loader.

    Usage:
        python -m detectron2.data.data.cityscapes \
            cityscapes/leftImg8bit/train cityscapes/gtFine/train
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("image_dir")
    parser.add_argument("gt_dir")
    parser.add_argument("--type", choices=["instance", "semantic"], default="instance")
    args = parser.parse_args()
    from detectron2.data.catalog import Metadata
    from detectron2.utils.visualizer import Visualizer
    from cityscapesscripts.helpers.labels import labels

    logger = setup_logger(name=__name__)

    dirname = "cityscapes-data-vis"
    os.makedirs(dirname, exist_ok=True)

    if args.type == "instance":
        dicts = load_cityscapes_instances(
            args.image_dir, args.gt_dir, from_json=True, to_polygons=True
        )
        logger.info("Done loading {} samples.".format(len(dicts)))

        thing_classes = [k.name for k in labels if k.hasInstances and not k.ignoreInEval]
        meta = Metadata().set(thing_classes=thing_classes)

    else:
        dicts = load_cityscapes_semantic(args.image_dir, args.gt_dir)
        logger.info("Done loading {} samples.".format(len(dicts)))

        stuff_names = [k.name for k in labels if k.trainId != 255]
        stuff_colors = [k.color for k in labels if k.trainId != 255]
        meta = Metadata().set(stuff_names=stuff_names, stuff_colors=stuff_colors)

    for d in dicts:
        img = np.array(Image.open(PathManager.open(d["file_name"], "rb")))
        visualizer = Visualizer(img, metadata=meta)
        vis = visualizer.draw_dataset_dict(d)
        # cv2.imshow("a", vis.get_image()[:, :, ::-1])
        # cv2.waitKey()
        fpath = os.path.join(dirname, os.path.basename(d["file_name"]))
        vis.save(fpath)
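The png-based branch in cityscapes_files_to_dict relies on the Cityscapes instance-id convention: non-crowd instances are stored in *_instanceIds.png as label_id * 1000 + instance_index, while crowd regions carry the bare label id (< 1000). A small worked example of that decoding, with a hypothetical pixel value (26 is "car" in the standard Cityscapes label set):

instance_id = 26005  # hypothetical value read from a *_instanceIds.png pixel
label_id = instance_id // 1000 if instance_id >= 1000 else instance_id
is_crowd = instance_id < 1000
assert (label_id, is_crowd) == (26, False)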
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/datasets/coco.py
ADDED
@@ -0,0 +1,466 @@

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import contextlib
import datetime
import io
import json
import logging
import numpy as np
import os
import pycocotools.mask as mask_util
from fvcore.common.file_io import PathManager, file_lock
from fvcore.common.timer import Timer
from PIL import Image

from detectron2.structures import Boxes, BoxMode, PolygonMasks

from .. import DatasetCatalog, MetadataCatalog

"""
This file contains functions to parse COCO-format annotations into dicts in "Detectron2 format".
"""


logger = logging.getLogger(__name__)

__all__ = ["load_coco_json", "load_sem_seg", "convert_to_coco_json"]


def load_coco_json(json_file, image_root, dataset_name=None, extra_annotation_keys=None):
    """
    Load a json file with COCO's instances annotation format.
    Currently supports instance detection, instance segmentation,
    and person keypoints annotations.

    Args:
        json_file (str): full path to the json file in COCO instances annotation format.
        image_root (str or path-like): the directory where the images in this json file exists.
        dataset_name (str): the name of the dataset (e.g., coco_2017_train).
            If provided, this function will also put "thing_classes" into
            the metadata associated with this dataset.
        extra_annotation_keys (list[str]): list of per-annotation keys that should also be
            loaded into the dataset dict (besides "iscrowd", "bbox", "keypoints",
            "category_id", "segmentation"). The values for these keys will be returned as-is.
            For example, the densepose annotations are loaded in this way.

    Returns:
        list[dict]: a list of dicts in Detectron2 standard dataset dicts format. (See
        `Using Custom Datasets </tutorials/data.html>`_ )

    Notes:
        1. This function does not read the image files.
           The results do not have the "image" field.
    """
    from pycocotools.coco import COCO

    timer = Timer()
    json_file = PathManager.get_local_path(json_file)
    with contextlib.redirect_stdout(io.StringIO()):
        coco_api = COCO(json_file)
    if timer.seconds() > 1:
        logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))

    id_map = None
    if dataset_name is not None:
        meta = MetadataCatalog.get(dataset_name)
        cat_ids = sorted(coco_api.getCatIds())
        cats = coco_api.loadCats(cat_ids)
        # The categories in a custom json file may not be sorted.
        thing_classes = [c["name"] for c in sorted(cats, key=lambda x: x["id"])]
        meta.thing_classes = thing_classes

        # In COCO, certain category ids are artificially removed,
        # and by convention they are always ignored.
        # We deal with COCO's id issue and translate
        # the category ids to contiguous ids in [0, 80).

        # It works by looking at the "categories" field in the json, therefore
        # if users' own json also have incontiguous ids, we'll
        # apply this mapping as well but print a warning.
        if not (min(cat_ids) == 1 and max(cat_ids) == len(cat_ids)):
            if "coco" not in dataset_name:
                logger.warning(
                    """
Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you.
"""
                )
        id_map = {v: i for i, v in enumerate(cat_ids)}
        meta.thing_dataset_id_to_contiguous_id = id_map

    # sort indices for reproducible results
    img_ids = sorted(coco_api.imgs.keys())
    # imgs is a list of dicts, each looks something like:
    # {'license': 4,
    #  'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
    #  'file_name': 'COCO_val2014_000000001268.jpg',
    #  'height': 427,
    #  'width': 640,
    #  'date_captured': '2013-11-17 05:57:24',
    #  'id': 1268}
    imgs = coco_api.loadImgs(img_ids)
    # anns is a list[list[dict]], where each dict is an annotation
    # record for an object. The inner list enumerates the objects in an image
    # and the outer list enumerates over images. Example of anns[0]:
    # [{'segmentation': [[192.81,
    #     247.09,
    #     ...
    #     219.03,
    #     249.06]],
    #   'area': 1035.749,
    #   'iscrowd': 0,
    #   'image_id': 1268,
    #   'bbox': [192.81, 224.8, 74.73, 33.43],
    #   'category_id': 16,
    #   'id': 42986},
    #  ...]
    anns = [coco_api.imgToAnns[img_id] for img_id in img_ids]

    if "minival" not in json_file:
        # The popular valminusminival & minival annotations for COCO2014 contain this bug.
        # However the ratio of buggy annotations there is tiny and does not affect accuracy.
        # Therefore we explicitly white-list them.
        ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
        assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
            json_file
        )

    imgs_anns = list(zip(imgs, anns))

    logger.info("Loaded {} images in COCO format from {}".format(len(imgs_anns), json_file))

    dataset_dicts = []

    ann_keys = ["iscrowd", "bbox", "keypoints", "category_id"] + (extra_annotation_keys or [])

    num_instances_without_valid_segmentation = 0

    for (img_dict, anno_dict_list) in imgs_anns:
        record = {}
        record["file_name"] = os.path.join(image_root, img_dict["file_name"])
        record["height"] = img_dict["height"]
        record["width"] = img_dict["width"]
        image_id = record["image_id"] = img_dict["id"]

        objs = []
        for anno in anno_dict_list:
            # Check that the image_id in this annotation is the same as
            # the image_id we're looking at.
            # This fails only when the data parsing logic or the annotation file is buggy.

            # The original COCO valminusminival2014 & minival2014 annotation files
            # actually contains bugs that, together with certain ways of using COCO API,
            # can trigger this assertion.
            assert anno["image_id"] == image_id

            assert anno.get("ignore", 0) == 0, '"ignore" in COCO json file is not supported.'

            obj = {key: anno[key] for key in ann_keys if key in anno}

            segm = anno.get("segmentation", None)
            if segm:  # either list[list[float]] or dict(RLE)
                if not isinstance(segm, dict):
                    # filter out invalid polygons (< 3 points)
|
| 162 |
+
segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
|
| 163 |
+
if len(segm) == 0:
|
| 164 |
+
num_instances_without_valid_segmentation += 1
|
| 165 |
+
continue # ignore this instance
|
| 166 |
+
obj["segmentation"] = segm
|
| 167 |
+
|
| 168 |
+
keypts = anno.get("keypoints", None)
|
| 169 |
+
if keypts: # list[int]
|
| 170 |
+
for idx, v in enumerate(keypts):
|
| 171 |
+
if idx % 3 != 2:
|
| 172 |
+
# COCO's segmentation coordinates are floating points in [0, H or W],
|
| 173 |
+
# but keypoint coordinates are integers in [0, H-1 or W-1]
|
| 174 |
+
# Therefore we assume the coordinates are "pixel indices" and
|
| 175 |
+
# add 0.5 to convert to floating point coordinates.
|
| 176 |
+
keypts[idx] = v + 0.5
|
| 177 |
+
obj["keypoints"] = keypts
|
| 178 |
+
|
| 179 |
+
obj["bbox_mode"] = BoxMode.XYWH_ABS
|
| 180 |
+
if id_map:
|
| 181 |
+
obj["category_id"] = id_map[obj["category_id"]]
|
| 182 |
+
objs.append(obj)
|
| 183 |
+
record["annotations"] = objs
|
| 184 |
+
dataset_dicts.append(record)
|
| 185 |
+
|
| 186 |
+
if num_instances_without_valid_segmentation > 0:
|
| 187 |
+
logger.warning(
|
| 188 |
+
"Filtered out {} instances without valid segmentation. "
|
| 189 |
+
"There might be issues in your dataset generation process.".format(
|
| 190 |
+
num_instances_without_valid_segmentation
|
| 191 |
+
)
|
| 192 |
+
)
|
| 193 |
+
return dataset_dicts
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def load_sem_seg(gt_root, image_root, gt_ext="png", image_ext="jpg"):
|
| 197 |
+
"""
|
| 198 |
+
Load semantic segmentation data. All files under "gt_root" with "gt_ext" extension are
|
| 199 |
+
treated as ground truth annotations and all files under "image_root" with "image_ext" extension
|
| 200 |
+
as input images. Ground truth and input images are matched using file paths relative to
|
| 201 |
+
"gt_root" and "image_root" respectively without taking into account file extensions.
|
| 202 |
+
This works for COCO as well as some other data.
|
| 203 |
+
|
| 204 |
+
Args:
|
| 205 |
+
gt_root (str): full path to ground truth semantic segmentation files. Semantic segmentation
|
| 206 |
+
annotations are stored as images with integer values in pixels that represent
|
| 207 |
+
corresponding semantic labels.
|
| 208 |
+
image_root (str): the directory where the input images are.
|
| 209 |
+
gt_ext (str): file extension for ground truth annotations.
|
| 210 |
+
image_ext (str): file extension for input images.
|
| 211 |
+
|
| 212 |
+
Returns:
|
| 213 |
+
list[dict]:
|
| 214 |
+
a list of dicts in detectron2 standard format without instance-level
|
| 215 |
+
annotation.
|
| 216 |
+
|
| 217 |
+
Notes:
|
| 218 |
+
1. This function does not read the image and ground truth files.
|
| 219 |
+
The results do not have the "image" and "sem_seg" fields.
|
| 220 |
+
"""
|
| 221 |
+
|
| 222 |
+
# We match input images with ground truth based on their relative filepaths (without file
|
| 223 |
+
# extensions) starting from 'image_root' and 'gt_root' respectively.
|
| 224 |
+
def file2id(folder_path, file_path):
|
| 225 |
+
# extract relative path starting from `folder_path`
|
| 226 |
+
image_id = os.path.normpath(os.path.relpath(file_path, start=folder_path))
|
| 227 |
+
# remove file extension
|
| 228 |
+
image_id = os.path.splitext(image_id)[0]
|
| 229 |
+
return image_id
|
| 230 |
+
|
| 231 |
+
input_files = sorted(
|
| 232 |
+
(os.path.join(image_root, f) for f in PathManager.ls(image_root) if f.endswith(image_ext)),
|
| 233 |
+
key=lambda file_path: file2id(image_root, file_path),
|
| 234 |
+
)
|
| 235 |
+
gt_files = sorted(
|
| 236 |
+
(os.path.join(gt_root, f) for f in PathManager.ls(gt_root) if f.endswith(gt_ext)),
|
| 237 |
+
key=lambda file_path: file2id(gt_root, file_path),
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
assert len(gt_files) > 0, "No annotations found in {}.".format(gt_root)
|
| 241 |
+
|
| 242 |
+
# Use the intersection, so that val2017_100 annotations can run smoothly with val2017 images
|
| 243 |
+
if len(input_files) != len(gt_files):
|
| 244 |
+
logger.warn(
|
| 245 |
+
"Directory {} and {} has {} and {} files, respectively.".format(
|
| 246 |
+
image_root, gt_root, len(input_files), len(gt_files)
|
| 247 |
+
)
|
| 248 |
+
)
|
| 249 |
+
input_basenames = [os.path.basename(f)[: -len(image_ext)] for f in input_files]
|
| 250 |
+
gt_basenames = [os.path.basename(f)[: -len(gt_ext)] for f in gt_files]
|
| 251 |
+
intersect = list(set(input_basenames) & set(gt_basenames))
|
| 252 |
+
# sort, otherwise each worker may obtain a list[dict] in different order
|
| 253 |
+
intersect = sorted(intersect)
|
| 254 |
+
logger.warn("Will use their intersection of {} files.".format(len(intersect)))
|
| 255 |
+
input_files = [os.path.join(image_root, f + image_ext) for f in intersect]
|
| 256 |
+
gt_files = [os.path.join(gt_root, f + gt_ext) for f in intersect]
|
| 257 |
+
|
| 258 |
+
logger.info(
|
| 259 |
+
"Loaded {} images with semantic segmentation from {}".format(len(input_files), image_root)
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
dataset_dicts = []
|
| 263 |
+
for (img_path, gt_path) in zip(input_files, gt_files):
|
| 264 |
+
record = {}
|
| 265 |
+
record["file_name"] = img_path
|
| 266 |
+
record["sem_seg_file_name"] = gt_path
|
| 267 |
+
dataset_dicts.append(record)
|
| 268 |
+
|
| 269 |
+
return dataset_dicts
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def convert_to_coco_dict(dataset_name):
|
| 273 |
+
"""
|
| 274 |
+
Convert an instance detection/segmentation or keypoint detection dataset
|
| 275 |
+
in detectron2's standard format into COCO json format.
|
| 276 |
+
|
| 277 |
+
Generic dataset description can be found here:
|
| 278 |
+
https://detectron2.readthedocs.io/tutorials/datasets.html#register-a-dataset
|
| 279 |
+
|
| 280 |
+
COCO data format description can be found here:
|
| 281 |
+
http://cocodataset.org/#format-data
|
| 282 |
+
|
| 283 |
+
Args:
|
| 284 |
+
dataset_name (str):
|
| 285 |
+
name of the source dataset
|
| 286 |
+
Must be registered in DatastCatalog and in detectron2's standard format.
|
| 287 |
+
Must have corresponding metadata "thing_classes"
|
| 288 |
+
Returns:
|
| 289 |
+
coco_dict: serializable dict in COCO json format
|
| 290 |
+
"""
|
| 291 |
+
|
| 292 |
+
dataset_dicts = DatasetCatalog.get(dataset_name)
|
| 293 |
+
metadata = MetadataCatalog.get(dataset_name)
|
| 294 |
+
|
| 295 |
+
# unmap the category mapping ids for COCO
|
| 296 |
+
if hasattr(metadata, "thing_dataset_id_to_contiguous_id"):
|
| 297 |
+
reverse_id_mapping = {v: k for k, v in metadata.thing_dataset_id_to_contiguous_id.items()}
|
| 298 |
+
reverse_id_mapper = lambda contiguous_id: reverse_id_mapping[contiguous_id] # noqa
|
| 299 |
+
else:
|
| 300 |
+
reverse_id_mapper = lambda contiguous_id: contiguous_id # noqa
|
| 301 |
+
|
| 302 |
+
categories = [
|
| 303 |
+
{"id": reverse_id_mapper(id), "name": name}
|
| 304 |
+
for id, name in enumerate(metadata.thing_classes)
|
| 305 |
+
]
|
| 306 |
+
|
| 307 |
+
logger.info("Converting dataset dicts into COCO format")
|
| 308 |
+
coco_images = []
|
| 309 |
+
coco_annotations = []
|
| 310 |
+
|
| 311 |
+
for image_id, image_dict in enumerate(dataset_dicts):
|
| 312 |
+
coco_image = {
|
| 313 |
+
"id": image_dict.get("image_id", image_id),
|
| 314 |
+
"width": image_dict["width"],
|
| 315 |
+
"height": image_dict["height"],
|
| 316 |
+
"file_name": image_dict["file_name"],
|
| 317 |
+
}
|
| 318 |
+
coco_images.append(coco_image)
|
| 319 |
+
|
| 320 |
+
anns_per_image = image_dict["annotations"]
|
| 321 |
+
for annotation in anns_per_image:
|
| 322 |
+
# create a new dict with only COCO fields
|
| 323 |
+
coco_annotation = {}
|
| 324 |
+
|
| 325 |
+
# COCO requirement: XYWH box format
|
| 326 |
+
bbox = annotation["bbox"]
|
| 327 |
+
bbox_mode = annotation["bbox_mode"]
|
| 328 |
+
bbox = BoxMode.convert(bbox, bbox_mode, BoxMode.XYWH_ABS)
|
| 329 |
+
|
| 330 |
+
# COCO requirement: instance area
|
| 331 |
+
if "segmentation" in annotation:
|
| 332 |
+
# Computing areas for instances by counting the pixels
|
| 333 |
+
segmentation = annotation["segmentation"]
|
| 334 |
+
# TODO: check segmentation type: RLE, BinaryMask or Polygon
|
| 335 |
+
if isinstance(segmentation, list):
|
| 336 |
+
polygons = PolygonMasks([segmentation])
|
| 337 |
+
area = polygons.area()[0].item()
|
| 338 |
+
elif isinstance(segmentation, dict): # RLE
|
| 339 |
+
area = mask_util.area(segmentation).item()
|
| 340 |
+
else:
|
| 341 |
+
raise TypeError(f"Unknown segmentation type {type(segmentation)}!")
|
| 342 |
+
else:
|
| 343 |
+
# Computing areas using bounding boxes
|
| 344 |
+
bbox_xy = BoxMode.convert(bbox, BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
|
| 345 |
+
area = Boxes([bbox_xy]).area()[0].item()
|
| 346 |
+
|
| 347 |
+
if "keypoints" in annotation:
|
| 348 |
+
keypoints = annotation["keypoints"] # list[int]
|
| 349 |
+
for idx, v in enumerate(keypoints):
|
| 350 |
+
if idx % 3 != 2:
|
| 351 |
+
# COCO's segmentation coordinates are floating points in [0, H or W],
|
| 352 |
+
# but keypoint coordinates are integers in [0, H-1 or W-1]
|
| 353 |
+
# For COCO format consistency we substract 0.5
|
| 354 |
+
# https://github.com/facebookresearch/detectron2/pull/175#issuecomment-551202163
|
| 355 |
+
keypoints[idx] = v - 0.5
|
| 356 |
+
if "num_keypoints" in annotation:
|
| 357 |
+
num_keypoints = annotation["num_keypoints"]
|
| 358 |
+
else:
|
| 359 |
+
num_keypoints = sum(kp > 0 for kp in keypoints[2::3])
|
| 360 |
+
|
| 361 |
+
# COCO requirement:
|
| 362 |
+
# linking annotations to images
|
| 363 |
+
# "id" field must start with 1
|
| 364 |
+
coco_annotation["id"] = len(coco_annotations) + 1
|
| 365 |
+
coco_annotation["image_id"] = coco_image["id"]
|
| 366 |
+
coco_annotation["bbox"] = [round(float(x), 3) for x in bbox]
|
| 367 |
+
coco_annotation["area"] = float(area)
|
| 368 |
+
coco_annotation["iscrowd"] = annotation.get("iscrowd", 0)
|
| 369 |
+
coco_annotation["category_id"] = reverse_id_mapper(annotation["category_id"])
|
| 370 |
+
|
| 371 |
+
# Add optional fields
|
| 372 |
+
if "keypoints" in annotation:
|
| 373 |
+
coco_annotation["keypoints"] = keypoints
|
| 374 |
+
coco_annotation["num_keypoints"] = num_keypoints
|
| 375 |
+
|
| 376 |
+
if "segmentation" in annotation:
|
| 377 |
+
coco_annotation["segmentation"] = annotation["segmentation"]
|
| 378 |
+
if isinstance(coco_annotation["segmentation"], dict): # RLE
|
| 379 |
+
coco_annotation["segmentation"]["counts"] = coco_annotation["segmentation"][
|
| 380 |
+
"counts"
|
| 381 |
+
].decode("ascii")
|
| 382 |
+
|
| 383 |
+
coco_annotations.append(coco_annotation)
|
| 384 |
+
|
| 385 |
+
logger.info(
|
| 386 |
+
"Conversion finished, "
|
| 387 |
+
f"#images: {len(coco_images)}, #annotations: {len(coco_annotations)}"
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
info = {
|
| 391 |
+
"date_created": str(datetime.datetime.now()),
|
| 392 |
+
"description": "Automatically generated COCO json file for Detectron2.",
|
| 393 |
+
}
|
| 394 |
+
coco_dict = {
|
| 395 |
+
"info": info,
|
| 396 |
+
"images": coco_images,
|
| 397 |
+
"annotations": coco_annotations,
|
| 398 |
+
"categories": categories,
|
| 399 |
+
"licenses": None,
|
| 400 |
+
}
|
| 401 |
+
return coco_dict
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def convert_to_coco_json(dataset_name, output_file, allow_cached=True):
|
| 405 |
+
"""
|
| 406 |
+
Converts dataset into COCO format and saves it to a json file.
|
| 407 |
+
dataset_name must be registered in DatasetCatalog and in detectron2's standard format.
|
| 408 |
+
|
| 409 |
+
Args:
|
| 410 |
+
dataset_name:
|
| 411 |
+
reference from the config file to the catalogs
|
| 412 |
+
must be registered in DatasetCatalog and in detectron2's standard format
|
| 413 |
+
output_file: path of json file that will be saved to
|
| 414 |
+
allow_cached: if json file is already present then skip conversion
|
| 415 |
+
"""
|
| 416 |
+
|
| 417 |
+
# TODO: The dataset or the conversion script *may* change,
|
| 418 |
+
# a checksum would be useful for validating the cached data
|
| 419 |
+
|
| 420 |
+
PathManager.mkdirs(os.path.dirname(output_file))
|
| 421 |
+
with file_lock(output_file):
|
| 422 |
+
if PathManager.exists(output_file) and allow_cached:
|
| 423 |
+
logger.warning(
|
| 424 |
+
f"Using previously cached COCO format annotations at '{output_file}'. "
|
| 425 |
+
"You need to clear the cache file if your dataset has been modified."
|
| 426 |
+
)
|
| 427 |
+
else:
|
| 428 |
+
logger.info(f"Converting annotations of dataset '{dataset_name}' to COCO format ...)")
|
| 429 |
+
coco_dict = convert_to_coco_dict(dataset_name)
|
| 430 |
+
|
| 431 |
+
logger.info(f"Caching COCO format annotations at '{output_file}' ...")
|
| 432 |
+
with PathManager.open(output_file, "w") as f:
|
| 433 |
+
json.dump(coco_dict, f)
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
if __name__ == "__main__":
|
| 437 |
+
"""
|
| 438 |
+
Test the COCO json dataset loader.
|
| 439 |
+
|
| 440 |
+
Usage:
|
| 441 |
+
python -m detectron2.data.data.coco \
|
| 442 |
+
path/to/json path/to/image_root dataset_name
|
| 443 |
+
|
| 444 |
+
"dataset_name" can be "coco_2014_minival_100", or other
|
| 445 |
+
pre-registered ones
|
| 446 |
+
"""
|
| 447 |
+
from detectron2.utils.logger import setup_logger
|
| 448 |
+
from detectron2.utils.visualizer import Visualizer
|
| 449 |
+
import detectron2.data.datasets # noqa # add pre-defined metadata
|
| 450 |
+
import sys
|
| 451 |
+
|
| 452 |
+
logger = setup_logger(name=__name__)
|
| 453 |
+
assert sys.argv[3] in DatasetCatalog.list()
|
| 454 |
+
meta = MetadataCatalog.get(sys.argv[3])
|
| 455 |
+
|
| 456 |
+
dicts = load_coco_json(sys.argv[1], sys.argv[2], sys.argv[3])
|
| 457 |
+
logger.info("Done loading {} samples.".format(len(dicts)))
|
| 458 |
+
|
| 459 |
+
dirname = "coco-data-vis"
|
| 460 |
+
os.makedirs(dirname, exist_ok=True)
|
| 461 |
+
for d in dicts:
|
| 462 |
+
img = np.array(Image.open(d["file_name"]))
|
| 463 |
+
visualizer = Visualizer(img, metadata=meta)
|
| 464 |
+
vis = visualizer.draw_dataset_dict(d)
|
| 465 |
+
fpath = os.path.join(dirname, os.path.basename(d["file_name"]))
|
| 466 |
+
vis.save(fpath)
|
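A minimal usage sketch for the loader and converter in `coco.py` above. The dataset name and the json/image paths are placeholders, not files shipped in this upload, and the import path assumes the vendored `detectron2` package is importable as installed.

```python
# Sketch only: "my_coco_train" and the paths below are hypothetical examples.
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets.coco import load_coco_json, convert_to_coco_json

json_file = "datasets/my_coco/annotations.json"  # placeholder path
image_root = "datasets/my_coco/images"           # placeholder path

# Register a lazy loader; load_coco_json also fills "thing_classes" metadata.
DatasetCatalog.register(
    "my_coco_train", lambda: load_coco_json(json_file, image_root, "my_coco_train")
)
MetadataCatalog.get("my_coco_train").set(
    json_file=json_file, image_root=image_root, evaluator_type="coco"
)

dicts = DatasetCatalog.get("my_coco_train")  # list of Detectron2 dataset dicts
print(len(dicts), dicts[0]["file_name"], len(dicts[0]["annotations"]))

# Round-trip back to COCO json; skipped if a cached file already exists.
convert_to_coco_json("my_coco_train", "output/my_coco_train_coco_format.json")
```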
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/datasets/lvis.py
ADDED
@@ -0,0 +1,209 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import os
from fvcore.common.file_io import PathManager
from fvcore.common.timer import Timer

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import BoxMode

from .builtin_meta import _get_coco_instances_meta
from .lvis_v0_5_categories import LVIS_CATEGORIES

"""
This file contains functions to parse LVIS-format annotations into dicts in the
"Detectron2 format".
"""

logger = logging.getLogger(__name__)

__all__ = ["load_lvis_json", "register_lvis_instances", "get_lvis_instances_meta"]


def register_lvis_instances(name, metadata, json_file, image_root):
    """
    Register a dataset in LVIS's json annotation format for instance detection and segmentation.

    Args:
        name (str): a name that identifies the dataset, e.g. "lvis_v0.5_train".
        metadata (dict): extra metadata associated with this dataset. It can be an empty dict.
        json_file (str): path to the json instance annotation file.
        image_root (str or path-like): directory which contains all the images.
    """
    DatasetCatalog.register(name, lambda: load_lvis_json(json_file, image_root, name))
    MetadataCatalog.get(name).set(
        json_file=json_file, image_root=image_root, evaluator_type="lvis", **metadata
    )


def load_lvis_json(json_file, image_root, dataset_name=None):
    """
    Load a json file in LVIS's annotation format.

    Args:
        json_file (str): full path to the LVIS json annotation file.
        image_root (str): the directory where the images in this json file exist.
        dataset_name (str): the name of the dataset (e.g., "lvis_v0.5_train").
            If provided, this function will put "thing_classes" into the metadata
            associated with this dataset.

    Returns:
        list[dict]: a list of dicts in Detectron2 standard format. (See
        `Using Custom Datasets </tutorials/data.html>`_ )

    Notes:
        1. This function does not read the image files.
           The results do not have the "image" field.
    """
    from lvis import LVIS

    json_file = PathManager.get_local_path(json_file)

    timer = Timer()
    lvis_api = LVIS(json_file)
    if timer.seconds() > 1:
        logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))

    if dataset_name is not None:
        meta = get_lvis_instances_meta(dataset_name)
        MetadataCatalog.get(dataset_name).set(**meta)

    # sort indices for reproducible results
    img_ids = sorted(lvis_api.imgs.keys())
    # imgs is a list of dicts, each looks something like:
    # {'license': 4,
    #  'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
    #  'file_name': 'COCO_val2014_000000001268.jpg',
    #  'height': 427,
    #  'width': 640,
    #  'date_captured': '2013-11-17 05:57:24',
    #  'id': 1268}
    imgs = lvis_api.load_imgs(img_ids)
    # anns is a list[list[dict]], where each dict is an annotation
    # record for an object. The inner list enumerates the objects in an image
    # and the outer list enumerates over images. Example of anns[0]:
    # [{'segmentation': [[192.81,
    #     247.09,
    #     ...
    #     219.03,
    #     249.06]],
    #   'area': 1035.749,
    #   'image_id': 1268,
    #   'bbox': [192.81, 224.8, 74.73, 33.43],
    #   'category_id': 16,
    #   'id': 42986},
    #  ...]
    anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]

    # Sanity check that each annotation has a unique id
    ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
    assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique".format(
        json_file
    )

    imgs_anns = list(zip(imgs, anns))

    logger.info("Loaded {} images in the LVIS format from {}".format(len(imgs_anns), json_file))

    dataset_dicts = []

    for (img_dict, anno_dict_list) in imgs_anns:
        record = {}
        file_name = img_dict["file_name"]
        if img_dict["file_name"].startswith("COCO"):
            # Convert from the COCO 2014 file naming convention of
            # COCO_[train/val/test]2014_000000000000.jpg to the 2017 naming convention of
            # 000000000000.jpg (LVIS v1 will fix this naming issue)
            file_name = file_name[-16:]
        record["file_name"] = os.path.join(image_root, file_name)
        record["height"] = img_dict["height"]
        record["width"] = img_dict["width"]
        record["not_exhaustive_category_ids"] = img_dict.get("not_exhaustive_category_ids", [])
        record["neg_category_ids"] = img_dict.get("neg_category_ids", [])
        image_id = record["image_id"] = img_dict["id"]

        objs = []
        for anno in anno_dict_list:
            # Check that the image_id in this annotation is the same as
            # the image_id we're looking at.
            # This fails only when the data parsing logic or the annotation file is buggy.
            assert anno["image_id"] == image_id
            obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS}
            obj["category_id"] = anno["category_id"] - 1  # Convert 1-indexed to 0-indexed
            segm = anno["segmentation"]  # list[list[float]]
            # filter out invalid polygons (< 3 points)
            valid_segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
            assert len(segm) == len(
                valid_segm
            ), "Annotation contains an invalid polygon with < 3 points"
            assert len(segm) > 0
            obj["segmentation"] = segm
            objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)

    return dataset_dicts


def get_lvis_instances_meta(dataset_name):
    """
    Load LVIS metadata.

    Args:
        dataset_name (str): LVIS dataset name without the split name (e.g., "lvis_v0.5").

    Returns:
        dict: LVIS metadata with keys: thing_classes
    """
    if "cocofied" in dataset_name:
        return _get_coco_instances_meta()
    if "v0.5" in dataset_name:
        return _get_lvis_instances_meta_v0_5()
    # There will be a v1 in the future
    # elif dataset_name == "lvis_v1":
    #   return get_lvis_instances_meta_v1()
    raise ValueError("No built-in metadata for dataset {}".format(dataset_name))


def _get_lvis_instances_meta_v0_5():
    assert len(LVIS_CATEGORIES) == 1230
    cat_ids = [k["id"] for k in LVIS_CATEGORIES]
    assert min(cat_ids) == 1 and max(cat_ids) == len(
        cat_ids
    ), "Category ids are not in [1, #categories], as expected"
    # Ensure that the category list is sorted by id
    lvis_categories = sorted(LVIS_CATEGORIES, key=lambda x: x["id"])
    thing_classes = [k["synonyms"][0] for k in lvis_categories]
    meta = {"thing_classes": thing_classes}
    return meta


if __name__ == "__main__":
    """
    Test the LVIS json dataset loader.

    Usage:
        python -m detectron2.data.data.lvis \
            path/to/json path/to/image_root dataset_name vis_limit
    """
    import sys
    import numpy as np
    from detectron2.utils.logger import setup_logger
    from PIL import Image
    import detectron2.data.datasets  # noqa # add pre-defined metadata
    from detectron2.utils.visualizer import Visualizer

    logger = setup_logger(name=__name__)
    meta = MetadataCatalog.get(sys.argv[3])

    dicts = load_lvis_json(sys.argv[1], sys.argv[2], sys.argv[3])
    logger.info("Done loading {} samples.".format(len(dicts)))

    dirname = "lvis-data-vis"
    os.makedirs(dirname, exist_ok=True)
    for d in dicts[: int(sys.argv[4])]:
        img = np.array(Image.open(d["file_name"]))
        visualizer = Visualizer(img, metadata=meta)
        vis = visualizer.draw_dataset_dict(d)
        fpath = os.path.join(dirname, os.path.basename(d["file_name"]))
        vis.save(fpath)
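A hedged sketch of how `register_lvis_instances` above would typically be called. The dataset name and paths are placeholders, and the `lvis` pip package is assumed to be installed; the name must contain "v0.5" (or "cocofied") so that `get_lvis_instances_meta` can resolve the metadata.

```python
# Sketch only: name and paths are hypothetical.
from detectron2.data import DatasetCatalog
from detectron2.data.datasets.lvis import register_lvis_instances

register_lvis_instances(
    name="lvis_v0.5_custom_train",                   # placeholder; contains "v0.5"
    metadata={},                                     # load_lvis_json fills thing_classes later
    json_file="datasets/lvis/lvis_v0.5_train.json",  # placeholder path
    image_root="datasets/coco/train2017",            # placeholder path
)

# The json is only parsed when the dicts are requested.
dicts = DatasetCatalog.get("lvis_v0.5_custom_train")
print(len(dicts), dicts[0]["annotations"][0]["bbox"])
```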
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/datasets/lvis_v0_5_categories.py
ADDED
The diff for this file is too large to render.
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/datasets/pascal_voc.py
ADDED
@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import numpy as np
import os
import xml.etree.ElementTree as ET
from fvcore.common.file_io import PathManager

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import BoxMode

__all__ = ["register_pascal_voc"]


# fmt: off
CLASS_NAMES = [
    "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
    "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]
# fmt: on


def load_voc_instances(dirname: str, split: str):
    """
    Load Pascal VOC detection annotations to Detectron2 format.

    Args:
        dirname: Contain "Annotations", "ImageSets", "JPEGImages"
        split (str): one of "train", "test", "val", "trainval"
    """
    with PathManager.open(os.path.join(dirname, "ImageSets", "Main", split + ".txt")) as f:
        fileids = np.loadtxt(f, dtype=np.str)

    # Needs to read many small annotation files. Makes sense at local
    annotation_dirname = PathManager.get_local_path(os.path.join(dirname, "Annotations/"))
    dicts = []
    for fileid in fileids:
        anno_file = os.path.join(annotation_dirname, fileid + ".xml")
        jpeg_file = os.path.join(dirname, "JPEGImages", fileid + ".jpg")

        with PathManager.open(anno_file) as f:
            tree = ET.parse(f)

        r = {
            "file_name": jpeg_file,
            "image_id": fileid,
            "height": int(tree.findall("./size/height")[0].text),
            "width": int(tree.findall("./size/width")[0].text),
        }
        instances = []

        for obj in tree.findall("object"):
            cls = obj.find("name").text
            # We include "difficult" samples in training.
            # Based on limited experiments, they don't hurt accuracy.
            # difficult = int(obj.find("difficult").text)
            # if difficult == 1:
            #     continue
            bbox = obj.find("bndbox")
            bbox = [float(bbox.find(x).text) for x in ["xmin", "ymin", "xmax", "ymax"]]
            # Original annotations are integers in the range [1, W or H]
            # Assuming they mean 1-based pixel indices (inclusive),
            # a box with annotation (xmin=1, xmax=W) covers the whole image.
            # In coordinate space this is represented by (xmin=0, xmax=W)
            bbox[0] -= 1.0
            bbox[1] -= 1.0
            instances.append(
                {"category_id": CLASS_NAMES.index(cls), "bbox": bbox, "bbox_mode": BoxMode.XYXY_ABS}
            )
        r["annotations"] = instances
        dicts.append(r)
    return dicts


def register_pascal_voc(name, dirname, split, year):
    DatasetCatalog.register(name, lambda: load_voc_instances(dirname, split))
    MetadataCatalog.get(name).set(
        thing_classes=CLASS_NAMES, dirname=dirname, year=year, split=split
    )
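A short usage sketch of `register_pascal_voc` above, assuming the standard VOC directory layout; the dataset name and the `datasets/VOC2007` path are placeholders.

```python
# Sketch only: the registered name and VOC root directory are hypothetical.
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets.pascal_voc import register_pascal_voc

register_pascal_voc("voc_2007_custom_trainval", "datasets/VOC2007", "trainval", 2007)

print(MetadataCatalog.get("voc_2007_custom_trainval").thing_classes[:3])
dicts = DatasetCatalog.get("voc_2007_custom_trainval")  # parses the per-image XML files
print(len(dicts), dicts[0]["annotations"][0]["bbox_mode"])
```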
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/datasets/register_coco.py
ADDED
@@ -0,0 +1,129 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import os

from detectron2.data import DatasetCatalog, MetadataCatalog

from .coco import load_coco_json, load_sem_seg

"""
This file contains functions to register a COCO-format dataset to the DatasetCatalog.
"""

__all__ = ["register_coco_instances", "register_coco_panoptic_separated"]


def register_coco_instances(name, metadata, json_file, image_root):
    """
    Register a dataset in COCO's json annotation format for
    instance detection, instance segmentation and keypoint detection.
    (i.e., Type 1 and 2 in http://cocodataset.org/#format-data.
    `instances*.json` and `person_keypoints*.json` in the dataset).

    This is an example of how to register a new dataset.
    You can do something similar to this function, to register new data.

    Args:
        name (str): the name that identifies a dataset, e.g. "coco_2014_train".
        metadata (dict): extra metadata associated with this dataset. You can
            leave it as an empty dict.
        json_file (str): path to the json instance annotation file.
        image_root (str or path-like): directory which contains all the images.
    """
    assert isinstance(name, str), name
    assert isinstance(json_file, (str, os.PathLike)), json_file
    assert isinstance(image_root, (str, os.PathLike)), image_root
    # 1. register a function which returns dicts
    DatasetCatalog.register(name, lambda: load_coco_json(json_file, image_root, name))

    # 2. Optionally, add metadata about this dataset,
    # since they might be useful in evaluation, visualization or logging
    MetadataCatalog.get(name).set(
        json_file=json_file, image_root=image_root, evaluator_type="coco", **metadata
    )


def register_coco_panoptic_separated(
    name, metadata, image_root, panoptic_root, panoptic_json, sem_seg_root, instances_json
):
    """
    Register a COCO panoptic segmentation dataset named `name`.
    The annotations in this registered dataset will contain both instance annotations and
    semantic annotations, each with its own contiguous ids. Hence it's called "separated".

    It follows the setting used by the PanopticFPN paper:

    1. The instance annotations directly come from polygons in the COCO
       instances annotation task, rather than from the masks in the COCO panoptic annotations.

       The two formats have small differences:
       Polygons in the instance annotations may have overlaps.
       The mask annotations are produced by labeling the overlapped polygons
       with depth ordering.

    2. The semantic annotations are converted from panoptic annotations, where
       all "things" are assigned a semantic id of 0.
       All semantic categories will therefore have ids in contiguous
       range [1, #stuff_categories].

    This function will also register a pure semantic segmentation dataset
    named ``name + '_stuffonly'``.

    Args:
        name (str): the name that identifies a dataset,
            e.g. "coco_2017_train_panoptic"
        metadata (dict): extra metadata associated with this dataset.
        image_root (str): directory which contains all the images
        panoptic_root (str): directory which contains panoptic annotation images
        panoptic_json (str): path to the json panoptic annotation file
        sem_seg_root (str): directory which contains all the ground truth segmentation annotations.
        instances_json (str): path to the json instance annotation file
    """
    panoptic_name = name + "_separated"
    DatasetCatalog.register(
        panoptic_name,
        lambda: merge_to_panoptic(
            load_coco_json(instances_json, image_root, panoptic_name),
            load_sem_seg(sem_seg_root, image_root),
        ),
    )
    MetadataCatalog.get(panoptic_name).set(
        panoptic_root=panoptic_root,
        image_root=image_root,
        panoptic_json=panoptic_json,
        sem_seg_root=sem_seg_root,
        json_file=instances_json,  # TODO rename
        evaluator_type="coco_panoptic_seg",
        **metadata
    )

    semantic_name = name + "_stuffonly"
    DatasetCatalog.register(semantic_name, lambda: load_sem_seg(sem_seg_root, image_root))
    MetadataCatalog.get(semantic_name).set(
        sem_seg_root=sem_seg_root, image_root=image_root, evaluator_type="sem_seg", **metadata
    )


def merge_to_panoptic(detection_dicts, sem_seg_dicts):
    """
    Create dataset dicts for panoptic segmentation, by
    merging two dicts using "file_name" field to match their entries.

    Args:
        detection_dicts (list[dict]): lists of dicts for object detection or instance segmentation.
        sem_seg_dicts (list[dict]): lists of dicts for semantic segmentation.

    Returns:
        list[dict] (one per input image): Each dict contains all (key, value) pairs from dicts in
            both detection_dicts and sem_seg_dicts that correspond to the same image.
            The function assumes that the same key in different dicts has the same value.
    """
    results = []
    sem_seg_file_to_entry = {x["file_name"]: x for x in sem_seg_dicts}
    assert len(sem_seg_file_to_entry) > 0

    for det_dict in detection_dicts:
        dic = copy.copy(det_dict)
        dic.update(sem_seg_file_to_entry[dic["file_name"]])
        results.append(dic)
    return results
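A minimal sketch of the registration entry point above, for a user-provided COCO-format dataset; the name and paths are placeholders, and the import assumes the vendored package is importable (in upstream detectron2 the same function is also re-exported from `detectron2.data.datasets`).

```python
# Sketch only: dataset name and file paths below are hypothetical.
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets.register_coco import register_coco_instances

register_coco_instances(
    "my_instances_train",                    # placeholder dataset name
    {},                                      # extra metadata may stay empty
    "datasets/my_dataset/annotations.json",  # placeholder json path
    "datasets/my_dataset/images",            # placeholder image dir
)

print(MetadataCatalog.get("my_instances_train").evaluator_type)  # "coco"
dicts = DatasetCatalog.get("my_instances_train")  # runs load_coco_json lazily
print(len(dicts))
```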
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/detection_utils.py
ADDED
@@ -0,0 +1,516 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Common data processing utilities that are used in a
typical object detection data pipeline.
"""
import logging
import numpy as np
import pycocotools.mask as mask_util
import torch
from fvcore.common.file_io import PathManager
from PIL import Image, ImageOps

from detectron2.structures import (
    BitMasks,
    Boxes,
    BoxMode,
    Instances,
    Keypoints,
    PolygonMasks,
    RotatedBoxes,
    polygons_to_bitmask,
)

from . import transforms as T
from .catalog import MetadataCatalog


class SizeMismatchError(ValueError):
    """
    When the loaded image has a different width/height compared with the annotation.
    """


# https://en.wikipedia.org/wiki/YUV#SDTV_with_BT.601
_M_RGB2YUV = [[0.299, 0.587, 0.114], [-0.14713, -0.28886, 0.436], [0.615, -0.51499, -0.10001]]
_M_YUV2RGB = [[1.0, 0.0, 1.13983], [1.0, -0.39465, -0.58060], [1.0, 2.03211, 0.0]]


def convert_PIL_to_numpy(image, format):
    """
    Convert PIL image to numpy array of target format.

    Args:
        image (PIL.Image): a PIL image
        format (str): the format of output image

    Returns:
        (np.ndarray): also see `read_image`
    """
    if format is not None:
        # PIL only supports RGB, so convert to RGB and flip channels over below
        conversion_format = format
        if format in ["BGR", "YUV-BT.601"]:
            conversion_format = "RGB"
        image = image.convert(conversion_format)
    image = np.asarray(image)
    # PIL squeezes out the channel dimension for "L", so make it HWC
    if format == "L":
        image = np.expand_dims(image, -1)

    # handle formats not supported by PIL
    elif format == "BGR":
        # flip channels if needed
        image = image[:, :, ::-1]
    elif format == "YUV-BT.601":
        image = image / 255.0
        image = np.dot(image, np.array(_M_RGB2YUV).T)

    return image


def convert_image_to_rgb(image, format):
    """
    Convert numpy image from given format to RGB.

    Args:
        image (np.ndarray): a numpy image
        format (str): the format of input image, also see `read_image`

    Returns:
        (np.ndarray): HWC RGB image in 0-255 range, can be either float or uint8
    """
    if format == "BGR":
        image = image[:, :, [2, 1, 0]]
    elif format == "YUV-BT.601":
        image = np.dot(image, np.array(_M_YUV2RGB).T)
        image = image * 255.0
    else:
        if format == "L":
            image = image[:, :, 0]
        image = image.astype(np.uint8)
        image = np.asarray(Image.fromarray(image, mode=format).convert("RGB"))
    return image


def read_image(file_name, format=None):
    """
    Read an image into the given format.
    Will apply rotation and flipping if the image has such exif information.

    Args:
        file_name (str): image file path
        format (str): one of the supported image modes in PIL, or "BGR" or "YUV-BT.601"

    Returns:
        image (np.ndarray): an HWC image in the given format, which is 0-255, uint8 for
            supported image modes in PIL or "BGR"; float (0-1 for Y) for YUV-BT.601.
    """
    with PathManager.open(file_name, "rb") as f:
        image = Image.open(f)

        # capture and ignore this bug: https://github.com/python-pillow/Pillow/issues/3973
        try:
            image = ImageOps.exif_transpose(image)
        except Exception:
            pass

        return convert_PIL_to_numpy(image, format)


def check_image_size(dataset_dict, image):
    """
    Raise an error if the image does not match the size specified in the dict.
    """
    if "width" in dataset_dict or "height" in dataset_dict:
        image_wh = (image.shape[1], image.shape[0])
        expected_wh = (dataset_dict["width"], dataset_dict["height"])
        if not image_wh == expected_wh:
            raise SizeMismatchError(
                "Mismatched (W,H){}, got {}, expect {}".format(
                    " for image " + dataset_dict["file_name"]
                    if "file_name" in dataset_dict
                    else "",
                    image_wh,
                    expected_wh,
                )
            )

    # To ensure bbox always remap to original image size
    if "width" not in dataset_dict:
        dataset_dict["width"] = image.shape[1]
    if "height" not in dataset_dict:
        dataset_dict["height"] = image.shape[0]


def transform_proposals(dataset_dict, image_shape, transforms, min_box_side_len, proposal_topk):
    """
    Apply transformations to the proposals in dataset_dict, if any.

    Args:
        dataset_dict (dict): a dict read from the dataset, possibly
            contains fields "proposal_boxes", "proposal_objectness_logits", "proposal_bbox_mode"
        image_shape (tuple): height, width
        transforms (TransformList):
        min_box_side_len (int): keep proposals with at least this size
        proposal_topk (int): only keep top-K scoring proposals

    The input dict is modified in-place, with abovementioned keys removed. A new
    key "proposals" will be added. Its value is an `Instances`
    object which contains the transformed proposals in its field
    "proposal_boxes" and "objectness_logits".
    """
    if "proposal_boxes" in dataset_dict:
        # Transform proposal boxes
        boxes = transforms.apply_box(
            BoxMode.convert(
                dataset_dict.pop("proposal_boxes"),
                dataset_dict.pop("proposal_bbox_mode"),
                BoxMode.XYXY_ABS,
            )
        )
        boxes = Boxes(boxes)
        objectness_logits = torch.as_tensor(
            dataset_dict.pop("proposal_objectness_logits").astype("float32")
        )

        boxes.clip(image_shape)
        keep = boxes.nonempty(threshold=min_box_side_len)
        boxes = boxes[keep]
        objectness_logits = objectness_logits[keep]

        proposals = Instances(image_shape)
        proposals.proposal_boxes = boxes[:proposal_topk]
        proposals.objectness_logits = objectness_logits[:proposal_topk]
        dataset_dict["proposals"] = proposals


def transform_instance_annotations(
    annotation, transforms, image_size, *, keypoint_hflip_indices=None
):
    """
    Apply transforms to box, segmentation and keypoints annotations of a single instance.

    It will use `transforms.apply_box` for the box, and
    `transforms.apply_coords` for segmentation polygons & keypoints.
    If you need anything more specially designed for each data structure,
    you'll need to implement your own version of this function or the transforms.

    Args:
        annotation (dict): dict of instance annotations for a single instance.
            It will be modified in-place.
        transforms (TransformList):
        image_size (tuple): the height, width of the transformed image
        keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`.

    Returns:
        dict:
            the same input dict with fields "bbox", "segmentation", "keypoints"
            transformed according to `transforms`.
            The "bbox_mode" field will be set to XYXY_ABS.
    """
    bbox = BoxMode.convert(annotation["bbox"], annotation["bbox_mode"], BoxMode.XYXY_ABS)
    # Note that bbox is 1d (per-instance bounding box)
    annotation["bbox"] = transforms.apply_box([bbox])[0]
    annotation["bbox_mode"] = BoxMode.XYXY_ABS

    if "segmentation" in annotation:
        # each instance contains 1 or more polygons
        segm = annotation["segmentation"]
        if isinstance(segm, list):
            # polygons
            polygons = [np.asarray(p).reshape(-1, 2) for p in segm]
            annotation["segmentation"] = [
                p.reshape(-1) for p in transforms.apply_polygons(polygons)
            ]
        elif isinstance(segm, dict):
            # RLE
            mask = mask_util.decode(segm)
            mask = transforms.apply_segmentation(mask)
            assert tuple(mask.shape[:2]) == image_size
            annotation["segmentation"] = mask
        else:
            raise ValueError(
                "Cannot transform segmentation of type '{}'!"
                "Supported types are: polygons as list[list[float] or ndarray],"
                " COCO-style RLE as a dict.".format(type(segm))
            )

    if "keypoints" in annotation:
        keypoints = transform_keypoint_annotations(
            annotation["keypoints"], transforms, image_size, keypoint_hflip_indices
        )
        annotation["keypoints"] = keypoints

    return annotation


def transform_keypoint_annotations(keypoints, transforms, image_size, keypoint_hflip_indices=None):
    """
    Transform keypoint annotations of an image.

    Args:
        keypoints (list[float]): Nx3 float in Detectron2 Dataset format.
        transforms (TransformList):
        image_size (tuple): the height, width of the transformed image
        keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`.
    """
    # (N*3,) -> (N, 3)
    keypoints = np.asarray(keypoints, dtype="float64").reshape(-1, 3)
    keypoints[:, :2] = transforms.apply_coords(keypoints[:, :2])

    # This assumes that HorizFlipTransform is the only one that does flip
    do_hflip = sum(isinstance(t, T.HFlipTransform) for t in transforms.transforms) % 2 == 1

    # Alternative way: check if probe points were horizontally flipped.
    # probe = np.asarray([[0.0, 0.0], [image_width, 0.0]])
    # probe_aug = transforms.apply_coords(probe.copy())
    # do_hflip = np.sign(probe[1][0] - probe[0][0]) != np.sign(probe_aug[1][0] - probe_aug[0][0])  # noqa

    # If flipped, swap each keypoint with its opposite-handed equivalent
    if do_hflip:
        assert keypoint_hflip_indices is not None
        keypoints = keypoints[keypoint_hflip_indices, :]

    # Maintain COCO convention that if visibility == 0, then x, y = 0
    # TODO may need to reset visibility for cropped keypoints,
    # but it does not matter for our existing algorithms
    keypoints[keypoints[:, 2] == 0] = 0
    return keypoints


def annotations_to_instances(annos, image_size, mask_format="polygon"):
    """
    Create an :class:`Instances` object used by the models,
    from instance annotations in the dataset dict.

    Args:
        annos (list[dict]): a list of instance annotations in one image, each
            element for one instance.
        image_size (tuple): height, width

    Returns:
        Instances:
            It will contain fields "gt_boxes", "gt_classes",
            "gt_masks", "gt_keypoints", if they can be obtained from `annos`.
            This is the format that builtin models expect.
    """
    boxes = [BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS) for obj in annos]
    target = Instances(image_size)
    boxes = target.gt_boxes = Boxes(boxes)
    boxes.clip(image_size)

    classes = [obj["category_id"] for obj in annos]
    classes = torch.tensor(classes, dtype=torch.int64)
    target.gt_classes = classes

    if len(annos) and "segmentation" in annos[0]:
        segms = [obj["segmentation"] for obj in annos]
        if mask_format == "polygon":
            masks = PolygonMasks(segms)
        else:
            assert mask_format == "bitmask", mask_format
            masks = []
            for segm in segms:
                if isinstance(segm, list):
                    # polygon
                    masks.append(polygons_to_bitmask(segm, *image_size))
                elif isinstance(segm, dict):
                    # COCO RLE
                    masks.append(mask_util.decode(segm))
                elif isinstance(segm, np.ndarray):
                    assert segm.ndim == 2, "Expect segmentation of 2 dimensions, got {}.".format(
                        segm.ndim
                    )
                    # mask array
                    masks.append(segm)
                else:
                    raise ValueError(
                        "Cannot convert segmentation of type '{}' to BitMasks!"
                        "Supported types are: polygons as list[list[float] or ndarray],"
                        " COCO-style RLE as a dict, or a full-image segmentation mask "
                        "as a 2D ndarray.".format(type(segm))
                    )
            # torch.from_numpy does not support array with negative stride.
            masks = BitMasks(
                torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
            )
        target.gt_masks = masks

    if len(annos) and "keypoints" in annos[0]:
        kpts = [obj.get("keypoints", []) for obj in annos]
        target.gt_keypoints = Keypoints(kpts)

    return target


def annotations_to_instances_rotated(annos, image_size):
    """
    Create an :class:`Instances` object used by the models,
    from instance annotations in the dataset dict.
    Compared to `annotations_to_instances`, this function is for rotated boxes only

    Args:
        annos (list[dict]): a list of instance annotations in one image, each
            element for one instance.
        image_size (tuple): height, width

    Returns:
        Instances:
            Containing fields "gt_boxes", "gt_classes",
            if they can be obtained from `annos`.
            This is the format that builtin models expect.
    """
    boxes = [obj["bbox"] for obj in annos]
    target = Instances(image_size)
    boxes = target.gt_boxes = RotatedBoxes(boxes)
    boxes.clip(image_size)

    classes = [obj["category_id"] for obj in annos]
    classes = torch.tensor(classes, dtype=torch.int64)
    target.gt_classes = classes

    return target


def filter_empty_instances(instances, by_box=True, by_mask=True, box_threshold=1e-5):
    """
    Filter out empty instances in an `Instances` object.

    Args:
        instances (Instances):
        by_box (bool): whether to filter out instances with empty boxes
        by_mask (bool): whether to filter out instances with empty masks
        box_threshold (float): minimum width and height to be considered non-empty

    Returns:
        Instances: the filtered instances.
    """
    assert by_box or by_mask
    r = []
    if by_box:
        r.append(instances.gt_boxes.nonempty(threshold=box_threshold))
    if instances.has("gt_masks") and by_mask:
        r.append(instances.gt_masks.nonempty())

    # TODO: can also filter visible keypoints

    if not r:
        return instances
    m = r[0]
    for x in r[1:]:
        m = m & x
    return instances[m]


def create_keypoint_hflip_indices(dataset_names):
    """
    Args:
        dataset_names (list[str]): list of dataset names
    Returns:
        ndarray[int]: a vector of size=#keypoints, storing the
        horizontally-flipped keypoint indices.
    """

    check_metadata_consistency("keypoint_names", dataset_names)
    check_metadata_consistency("keypoint_flip_map", dataset_names)

    meta = MetadataCatalog.get(dataset_names[0])
    names = meta.keypoint_names
    # TODO flip -> hflip
    flip_map = dict(meta.keypoint_flip_map)
    flip_map.update({v: k for k, v in flip_map.items()})
    flipped_names = [i if i not in flip_map else flip_map[i] for i in names]
    flip_indices = [names.index(i) for i in flipped_names]
    return np.asarray(flip_indices)


def gen_crop_transform_with_instance(crop_size, image_size, instance):
|
| 431 |
+
"""
|
| 432 |
+
Generate a CropTransform so that the cropping region contains
|
| 433 |
+
the center of the given instance.
|
| 434 |
+
|
| 435 |
+
Args:
|
| 436 |
+
crop_size (tuple): h, w in pixels
|
| 437 |
+
image_size (tuple): h, w
|
| 438 |
+
instance (dict): an annotation dict of one instance, in Detectron2's
|
| 439 |
+
dataset format.
|
| 440 |
+
"""
|
| 441 |
+
crop_size = np.asarray(crop_size, dtype=np.int32)
|
| 442 |
+
bbox = BoxMode.convert(instance["bbox"], instance["bbox_mode"], BoxMode.XYXY_ABS)
|
| 443 |
+
center_yx = (bbox[1] + bbox[3]) * 0.5, (bbox[0] + bbox[2]) * 0.5
|
| 444 |
+
assert (
|
| 445 |
+
image_size[0] >= center_yx[0] and image_size[1] >= center_yx[1]
|
| 446 |
+
), "The annotation bounding box is outside of the image!"
|
| 447 |
+
assert (
|
| 448 |
+
image_size[0] >= crop_size[0] and image_size[1] >= crop_size[1]
|
| 449 |
+
), "Crop size is larger than image size!"
|
| 450 |
+
|
| 451 |
+
min_yx = np.maximum(np.floor(center_yx).astype(np.int32) - crop_size, 0)
|
| 452 |
+
max_yx = np.maximum(np.asarray(image_size, dtype=np.int32) - crop_size, 0)
|
| 453 |
+
max_yx = np.minimum(max_yx, np.ceil(center_yx).astype(np.int32))
|
| 454 |
+
|
| 455 |
+
y0 = np.random.randint(min_yx[0], max_yx[0] + 1)
|
| 456 |
+
x0 = np.random.randint(min_yx[1], max_yx[1] + 1)
|
| 457 |
+
return T.CropTransform(x0, y0, crop_size[1], crop_size[0])
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def check_metadata_consistency(key, dataset_names):
|
| 461 |
+
"""
|
| 462 |
+
Check that the data have consistent metadata.
|
| 463 |
+
|
| 464 |
+
Args:
|
| 465 |
+
key (str): a metadata key
|
| 466 |
+
dataset_names (list[str]): a list of dataset names
|
| 467 |
+
|
| 468 |
+
Raises:
|
| 469 |
+
AttributeError: if the key does not exist in the metadata
|
| 470 |
+
ValueError: if the given data do not have the same metadata values defined by key
|
| 471 |
+
"""
|
| 472 |
+
if len(dataset_names) == 0:
|
| 473 |
+
return
|
| 474 |
+
logger = logging.getLogger(__name__)
|
| 475 |
+
entries_per_dataset = [getattr(MetadataCatalog.get(d), key) for d in dataset_names]
|
| 476 |
+
for idx, entry in enumerate(entries_per_dataset):
|
| 477 |
+
if entry != entries_per_dataset[0]:
|
| 478 |
+
logger.error(
|
| 479 |
+
"Metadata '{}' for dataset '{}' is '{}'".format(key, dataset_names[idx], str(entry))
|
| 480 |
+
)
|
| 481 |
+
logger.error(
|
| 482 |
+
"Metadata '{}' for dataset '{}' is '{}'".format(
|
| 483 |
+
key, dataset_names[0], str(entries_per_dataset[0])
|
| 484 |
+
)
|
| 485 |
+
)
|
| 486 |
+
raise ValueError("Datasets have different metadata '{}'!".format(key))
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
def build_transform_gen(cfg, is_train):
|
| 490 |
+
"""
|
| 491 |
+
Create a list of :class:`TransformGen` from config.
|
| 492 |
+
Now it includes resizing and flipping.
|
| 493 |
+
|
| 494 |
+
Returns:
|
| 495 |
+
list[TransformGen]
|
| 496 |
+
"""
|
| 497 |
+
if is_train:
|
| 498 |
+
min_size = cfg.INPUT.MIN_SIZE_TRAIN
|
| 499 |
+
max_size = cfg.INPUT.MAX_SIZE_TRAIN
|
| 500 |
+
sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
|
| 501 |
+
else:
|
| 502 |
+
min_size = cfg.INPUT.MIN_SIZE_TEST
|
| 503 |
+
max_size = cfg.INPUT.MAX_SIZE_TEST
|
| 504 |
+
sample_style = "choice"
|
| 505 |
+
if sample_style == "range":
|
| 506 |
+
assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format(
|
| 507 |
+
len(min_size)
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
logger = logging.getLogger(__name__)
|
| 511 |
+
tfm_gens = []
|
| 512 |
+
tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
|
| 513 |
+
if is_train:
|
| 514 |
+
tfm_gens.append(T.RandomFlip())
|
| 515 |
+
logger.info("TransformGens used in training: " + str(tfm_gens))
|
| 516 |
+
return tfm_gens
|
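A minimal usage sketch (not part of the upstream diff): how the helpers above are typically chained when a dataset dict is mapped into model inputs. The annotation dict below is hypothetical; field names follow the Detectron2 dataset format shown above.

# Sketch only: build Instances from one hypothetical annotation and filter empties.
from detectron2.structures import BoxMode

annos = [
    {"bbox": [10.0, 20.0, 50.0, 80.0], "bbox_mode": BoxMode.XYWH_ABS, "category_id": 0},
]
image_size = (480, 640)  # (height, width)

instances = annotations_to_instances(annos, image_size)  # gt_boxes, gt_classes, ...
instances = filter_empty_instances(instances)            # drop degenerate boxes/masks
print(instances.gt_boxes, instances.gt_classes)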
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/samplers/__init__.py
ADDED
@@ -0,0 +1,10 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .distributed_sampler import InferenceSampler, RepeatFactorTrainingSampler, TrainingSampler
from .grouped_batch_sampler import GroupedBatchSampler

__all__ = [
    "GroupedBatchSampler",
    "TrainingSampler",
    "InferenceSampler",
    "RepeatFactorTrainingSampler",
]
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/samplers/distributed_sampler.py
ADDED
@@ -0,0 +1,199 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import itertools
import math
from collections import defaultdict
from typing import Optional
import torch
from torch.utils.data.sampler import Sampler

from detectron2.utils import comm


class TrainingSampler(Sampler):
    """
    In training, we only care about the "infinite stream" of training data.
    So this sampler produces an infinite stream of indices and
    all workers cooperate to correctly shuffle the indices and sample different indices.

    The samplers in each worker effectively produces `indices[worker_id::num_workers]`
    where `indices` is an infinite stream of indices consisting of
    `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
    or `range(size) + range(size) + ...` (if shuffle is False)
    """

    def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = None):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
            shuffle (bool): whether to shuffle the indices or not
            seed (int): the initial seed of the shuffle. Must be the same
                across all workers. If None, will use a random seed shared
                among workers (require synchronization among all workers).
        """
        self._size = size
        assert size > 0
        self._shuffle = shuffle
        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

    def __iter__(self):
        start = self._rank
        yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        g = torch.Generator()
        g.manual_seed(self._seed)
        while True:
            if self._shuffle:
                yield from torch.randperm(self._size, generator=g)
            else:
                yield from torch.arange(self._size)


class RepeatFactorTrainingSampler(Sampler):
    """
    Similar to TrainingSampler, but suitable for training on class imbalanced data
    like LVIS. In each epoch, an image may appear multiple times based on its "repeat
    factor". The repeat factor for an image is a function of the frequency the rarest
    category labeled in that image. The "frequency of category c" in [0, 1] is defined
    as the fraction of images in the training set (without repeats) in which category c
    appears.

    See :paper:`lvis` (>= v2) Appendix B.2.
    """

    def __init__(self, dataset_dicts, repeat_thresh, shuffle=True, seed=None):
        """
        Args:
            dataset_dicts (list[dict]): annotations in Detectron2 dataset format.
            repeat_thresh (float): frequency threshold below which data is repeated.
            shuffle (bool): whether to shuffle the indices or not
            seed (int): the initial seed of the shuffle. Must be the same
                across all workers. If None, will use a random seed shared
                among workers (require synchronization among all workers).
        """
        self._shuffle = shuffle
        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

        # Get fractional repeat factors and split into whole number (_int_part)
        # and fractional (_frac_part) parts.
        rep_factors = self._get_repeat_factors(dataset_dicts, repeat_thresh)
        self._int_part = torch.trunc(rep_factors)
        self._frac_part = rep_factors - self._int_part

    def _get_repeat_factors(self, dataset_dicts, repeat_thresh):
        """
        Compute (fractional) per-image repeat factors.

        Args:
            See __init__.

        Returns:
            torch.Tensor: the i-th element is the repeat factor for the dataset image
                at index i.
        """
        # 1. For each category c, compute the fraction of images that contain it: f(c)
        category_freq = defaultdict(int)
        for dataset_dict in dataset_dicts:  # For each image (without repeats)
            cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
            for cat_id in cat_ids:
                category_freq[cat_id] += 1
        num_images = len(dataset_dicts)
        for k, v in category_freq.items():
            category_freq[k] = v / num_images

        # 2. For each category c, compute the category-level repeat factor:
        #    r(c) = max(1, sqrt(t / f(c)))
        category_rep = {
            cat_id: max(1.0, math.sqrt(repeat_thresh / cat_freq))
            for cat_id, cat_freq in category_freq.items()
        }

        # 3. For each image I, compute the image-level repeat factor:
        #    r(I) = max_{c in I} r(c)
        rep_factors = []
        for dataset_dict in dataset_dicts:
            cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
            rep_factor = max({category_rep[cat_id] for cat_id in cat_ids})
            rep_factors.append(rep_factor)

        return torch.tensor(rep_factors, dtype=torch.float32)

    def _get_epoch_indices(self, generator):
        """
        Create a list of dataset indices (with repeats) to use for one epoch.

        Args:
            generator (torch.Generator): pseudo random number generator used for
                stochastic rounding.

        Returns:
            torch.Tensor: list of dataset indices to use in one epoch. Each index
                is repeated based on its calculated repeat factor.
        """
        # Since repeat factors are fractional, we use stochastic rounding so
        # that the target repeat factor is achieved in expectation over the
        # course of training
        rands = torch.rand(len(self._frac_part), generator=generator)
        rep_factors = self._int_part + (rands < self._frac_part).float()
        # Construct a list of indices in which we repeat images as specified
        indices = []
        for dataset_index, rep_factor in enumerate(rep_factors):
            indices.extend([dataset_index] * int(rep_factor.item()))
        return torch.tensor(indices, dtype=torch.int64)

    def __iter__(self):
        start = self._rank
        yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        g = torch.Generator()
        g.manual_seed(self._seed)
        while True:
            # Sample indices with repeats determined by stochastic rounding; each
            # "epoch" may have a slightly different size due to the rounding.
            indices = self._get_epoch_indices(g)
            if self._shuffle:
                randperm = torch.randperm(len(indices), generator=g)
                yield from indices[randperm]
            else:
                yield from indices


class InferenceSampler(Sampler):
    """
    Produce indices for inference.
    Inference needs to run on the __exact__ set of samples,
    therefore when the total number of samples is not divisible by the number of workers,
    this sampler produces different number of samples on different workers.
    """

    def __init__(self, size: int):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
        """
        self._size = size
        assert size > 0
        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

        shard_size = (self._size - 1) // self._world_size + 1
        begin = shard_size * self._rank
        end = min(shard_size * (self._rank + 1), self._size)
        self._local_indices = range(begin, end)

    def __iter__(self):
        yield from self._local_indices

    def __len__(self):
        return len(self._local_indices)
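A minimal usage sketch (not part of the upstream diff): the samplers above yield an infinite, rank-sharded index stream, so they are normally wrapped in a batch sampler and handed to a PyTorch DataLoader. `ds` and the sizes below are hypothetical placeholders.

# Sketch only, assuming a map-style dataset `ds` with 1000 items and batch size 2.
import torch.utils.data as data

sampler = TrainingSampler(size=1000, shuffle=True, seed=42)
batch_sampler = data.BatchSampler(sampler, batch_size=2, drop_last=True)
loader = data.DataLoader(ds, batch_sampler=batch_sampler, num_workers=2)
# The stream never ends; the training loop reads a fixed number of iterations instead
# of relying on StopIteration.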
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/samplers/grouped_batch_sampler.py
ADDED
@@ -0,0 +1,47 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import numpy as np
from torch.utils.data.sampler import BatchSampler, Sampler


class GroupedBatchSampler(BatchSampler):
    """
    Wraps another sampler to yield a mini-batch of indices.
    It enforces that the batch only contain elements from the same group.
    It also tries to provide mini-batches which follows an ordering which is
    as close as possible to the ordering from the original sampler.
    """

    def __init__(self, sampler, group_ids, batch_size):
        """
        Args:
            sampler (Sampler): Base sampler.
            group_ids (list[int]): If the sampler produces indices in range [0, N),
                `group_ids` must be a list of `N` ints which contains the group id of each sample.
                The group ids must be a set of integers in the range [0, num_groups).
            batch_size (int): Size of mini-batch.
        """
        if not isinstance(sampler, Sampler):
            raise ValueError(
                "sampler should be an instance of "
                "torch.utils.data.Sampler, but got sampler={}".format(sampler)
            )
        self.sampler = sampler
        self.group_ids = np.asarray(group_ids)
        assert self.group_ids.ndim == 1
        self.batch_size = batch_size
        groups = np.unique(self.group_ids).tolist()

        # buffer the indices of each group until batch size is reached
        self.buffer_per_group = {k: [] for k in groups}

    def __iter__(self):
        for idx in self.sampler:
            group_id = self.group_ids[idx]
            group_buffer = self.buffer_per_group[group_id]
            group_buffer.append(idx)
            if len(group_buffer) == self.batch_size:
                yield group_buffer[:]  # yield a copy of the list
                del group_buffer[:]

    def __len__(self):
        raise NotImplementedError("len() of GroupedBatchSampler is not well-defined.")
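A minimal usage sketch (not part of the upstream diff): in the detection dataloader the group id is typically the image aspect-ratio bucket, so every mini-batch shares an orientation. The per-image sizes below are hypothetical.

# Sketch only: group by aspect ratio (0 = wide, 1 = tall) and batch within groups.
import numpy as np

widths = np.array([640, 480, 800, 500])
heights = np.array([480, 640, 600, 900])
group_ids = (heights > widths).astype(int).tolist()

sampler = TrainingSampler(size=len(group_ids), shuffle=True)
batch_sampler = GroupedBatchSampler(sampler, group_ids, batch_size=2)
for batch in batch_sampler:
    print(batch)  # indices whose images share the same group id
    break         # the underlying sampler is infinite, so stop explicitly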
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/transforms/__init__.py
ADDED
@@ -0,0 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .transform import *
from fvcore.transforms.transform import *
from .transform_gen import *

__all__ = [k for k in globals().keys() if not k.startswith("_")]
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/transforms/transform.py
ADDED
@@ -0,0 +1,241 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# File: transform.py

import numpy as np
import torch
import torch.nn.functional as F
from fvcore.transforms.transform import HFlipTransform, NoOpTransform, Transform
from PIL import Image

try:
    import cv2  # noqa
except ImportError:
    # OpenCV is an optional dependency at the moment
    pass

__all__ = ["ExtentTransform", "ResizeTransform", "RotationTransform"]


class ExtentTransform(Transform):
    """
    Extracts a subregion from the source image and scales it to the output size.

    The fill color is used to map pixels from the source rect that fall outside
    the source image.

    See: https://pillow.readthedocs.io/en/latest/PIL.html#PIL.ImageTransform.ExtentTransform
    """

    def __init__(self, src_rect, output_size, interp=Image.LINEAR, fill=0):
        """
        Args:
            src_rect (x0, y0, x1, y1): src coordinates
            output_size (h, w): dst image size
            interp: PIL interpolation methods
            fill: Fill color used when src_rect extends outside image
        """
        super().__init__()
        self._set_attributes(locals())

    def apply_image(self, img, interp=None):
        h, w = self.output_size
        ret = Image.fromarray(img).transform(
            size=(w, h),
            method=Image.EXTENT,
            data=self.src_rect,
            resample=interp if interp else self.interp,
            fill=self.fill,
        )
        return np.asarray(ret)

    def apply_coords(self, coords):
        # Transform image center from source coordinates into output coordinates
        # and then map the new origin to the corner of the output image.
        h, w = self.output_size
        x0, y0, x1, y1 = self.src_rect
        new_coords = coords.astype(np.float32)
        new_coords[:, 0] -= 0.5 * (x0 + x1)
        new_coords[:, 1] -= 0.5 * (y0 + y1)
        new_coords[:, 0] *= w / (x1 - x0)
        new_coords[:, 1] *= h / (y1 - y0)
        new_coords[:, 0] += 0.5 * w
        new_coords[:, 1] += 0.5 * h
        return new_coords

    def apply_segmentation(self, segmentation):
        segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
        return segmentation


class ResizeTransform(Transform):
    """
    Resize the image to a target size.
    """

    def __init__(self, h, w, new_h, new_w, interp=None):
        """
        Args:
            h, w (int): original image size
            new_h, new_w (int): new image size
            interp: PIL interpolation methods, defaults to bilinear.
        """
        # TODO decide on PIL vs opencv
        super().__init__()
        if interp is None:
            interp = Image.BILINEAR
        self._set_attributes(locals())

    def apply_image(self, img, interp=None):
        assert img.shape[:2] == (self.h, self.w)
        assert len(img.shape) <= 4

        if img.dtype == np.uint8:
            pil_image = Image.fromarray(img)
            interp_method = interp if interp is not None else self.interp
            pil_image = pil_image.resize((self.new_w, self.new_h), interp_method)
            ret = np.asarray(pil_image)
        else:
            # PIL only supports uint8
            img = torch.from_numpy(img)
            shape = list(img.shape)
            shape_4d = shape[:2] + [1] * (4 - len(shape)) + shape[2:]
            img = img.view(shape_4d).permute(2, 3, 0, 1)  # hw(c) -> nchw
            _PIL_RESIZE_TO_INTERPOLATE_MODE = {Image.BILINEAR: "bilinear", Image.BICUBIC: "bicubic"}
            mode = _PIL_RESIZE_TO_INTERPOLATE_MODE[self.interp]
            img = F.interpolate(img, (self.new_h, self.new_w), mode=mode, align_corners=False)
            shape[:2] = (self.new_h, self.new_w)
            ret = img.permute(2, 3, 0, 1).view(shape).numpy()  # nchw -> hw(c)

        return ret

    def apply_coords(self, coords):
        coords[:, 0] = coords[:, 0] * (self.new_w * 1.0 / self.w)
        coords[:, 1] = coords[:, 1] * (self.new_h * 1.0 / self.h)
        return coords

    def apply_segmentation(self, segmentation):
        segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
        return segmentation

    def inverse(self):
        return ResizeTransform(self.new_h, self.new_w, self.h, self.w, self.interp)


class RotationTransform(Transform):
    """
    This method returns a copy of this image, rotated the given
    number of degrees counter clockwise around its center.
    """

    def __init__(self, h, w, angle, expand=True, center=None, interp=None):
        """
        Args:
            h, w (int): original image size
            angle (float): degrees for rotation
            expand (bool): choose if the image should be resized to fit the whole
                rotated image (default), or simply cropped
            center (tuple (width, height)): coordinates of the rotation center
                if left to None, the center will be fit to the center of each image
                center has no effect if expand=True because it only affects shifting
            interp: cv2 interpolation method, default cv2.INTER_LINEAR
        """
        super().__init__()
        image_center = np.array((w / 2, h / 2))
        if center is None:
            center = image_center
        if interp is None:
            interp = cv2.INTER_LINEAR
        abs_cos, abs_sin = abs(np.cos(np.deg2rad(angle))), abs(np.sin(np.deg2rad(angle)))
        if expand:
            # find the new width and height bounds
            bound_w, bound_h = np.rint(
                [h * abs_sin + w * abs_cos, h * abs_cos + w * abs_sin]
            ).astype(int)
        else:
            bound_w, bound_h = w, h

        self._set_attributes(locals())
        self.rm_coords = self.create_rotation_matrix()
        # Needed because of this problem https://github.com/opencv/opencv/issues/11784
        self.rm_image = self.create_rotation_matrix(offset=-0.5)

    def apply_image(self, img, interp=None):
        """
        demo should be a numpy array, formatted as Height * Width * Nchannels
        """
        if len(img) == 0 or self.angle % 360 == 0:
            return img
        assert img.shape[:2] == (self.h, self.w)
        interp = interp if interp is not None else self.interp
        return cv2.warpAffine(img, self.rm_image, (self.bound_w, self.bound_h), flags=interp)

    def apply_coords(self, coords):
        """
        coords should be a N * 2 array-like, containing N couples of (x, y) points
        """
        coords = np.asarray(coords, dtype=float)
        if len(coords) == 0 or self.angle % 360 == 0:
            return coords
        return cv2.transform(coords[:, np.newaxis, :], self.rm_coords)[:, 0, :]

    def apply_segmentation(self, segmentation):
        segmentation = self.apply_image(segmentation, interp=cv2.INTER_NEAREST)
        return segmentation

    def create_rotation_matrix(self, offset=0):
        center = (self.center[0] + offset, self.center[1] + offset)
        rm = cv2.getRotationMatrix2D(tuple(center), self.angle, 1)
        if self.expand:
            # Find the coordinates of the center of rotation in the new image
            # The only point for which we know the future coordinates is the center of the image
            rot_im_center = cv2.transform(self.image_center[None, None, :] + offset, rm)[0, 0, :]
            new_center = np.array([self.bound_w / 2, self.bound_h / 2]) + offset - rot_im_center
            # shift the rotation center to the new coordinates
            rm[:, 2] += new_center
        return rm


def HFlip_rotated_box(transform, rotated_boxes):
    """
    Apply the horizontal flip transform on rotated boxes.

    Args:
        rotated_boxes (ndarray): Nx5 floating point array of
            (x_center, y_center, width, height, angle_degrees) format
            in absolute coordinates.
    """
    # Transform x_center
    rotated_boxes[:, 0] = transform.width - rotated_boxes[:, 0]
    # Transform angle
    rotated_boxes[:, 4] = -rotated_boxes[:, 4]
    return rotated_boxes


def Resize_rotated_box(transform, rotated_boxes):
    """
    Apply the resizing transform on rotated boxes. For details of how these (approximation)
    formulas are derived, please refer to :meth:`RotatedBoxes.scale`.

    Args:
        rotated_boxes (ndarray): Nx5 floating point array of
            (x_center, y_center, width, height, angle_degrees) format
            in absolute coordinates.
    """
    scale_factor_x = transform.new_w * 1.0 / transform.w
    scale_factor_y = transform.new_h * 1.0 / transform.h
    rotated_boxes[:, 0] *= scale_factor_x
    rotated_boxes[:, 1] *= scale_factor_y
    theta = rotated_boxes[:, 4] * np.pi / 180.0
    c = np.cos(theta)
    s = np.sin(theta)
    rotated_boxes[:, 2] *= np.sqrt(np.square(scale_factor_x * c) + np.square(scale_factor_y * s))
    rotated_boxes[:, 3] *= np.sqrt(np.square(scale_factor_x * s) + np.square(scale_factor_y * c))
    rotated_boxes[:, 4] = np.arctan2(scale_factor_x * s, scale_factor_y * c) * 180 / np.pi

    return rotated_boxes


HFlipTransform.register_type("rotated_box", HFlip_rotated_box)
NoOpTransform.register_type("rotated_box", lambda t, x: x)
ResizeTransform.register_type("rotated_box", Resize_rotated_box)
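A minimal usage sketch (not part of the upstream diff): each Transform above applies the same geometric operation to images, coordinates, and segmentations, so annotations stay aligned with the augmented image. The sketch assumes OpenCV is installed (RotationTransform needs it) and uses a dummy image.

# Sketch only: rotate an image and keep a point in sync with it.
import numpy as np

img = np.zeros((480, 640, 3), dtype=np.uint8)
rot = RotationTransform(h=480, w=640, angle=30, expand=True)

rotated_img = rot.apply_image(img)                  # shape becomes (bound_h, bound_w, 3)
pts = rot.apply_coords(np.array([[320.0, 240.0]]))  # the image center, rotated
print(rotated_img.shape, pts)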
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/data/transforms/transform_gen.py
ADDED
|
@@ -0,0 +1,534 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 3 |
+
# File: transformer.py
|
| 4 |
+
|
| 5 |
+
import inspect
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pprint
|
| 8 |
+
import sys
|
| 9 |
+
from abc import ABCMeta, abstractmethod
|
| 10 |
+
from fvcore.transforms.transform import (
|
| 11 |
+
BlendTransform,
|
| 12 |
+
CropTransform,
|
| 13 |
+
HFlipTransform,
|
| 14 |
+
NoOpTransform,
|
| 15 |
+
Transform,
|
| 16 |
+
TransformList,
|
| 17 |
+
VFlipTransform,
|
| 18 |
+
)
|
| 19 |
+
from PIL import Image
|
| 20 |
+
|
| 21 |
+
from .transform import ExtentTransform, ResizeTransform, RotationTransform
|
| 22 |
+
|
| 23 |
+
__all__ = [
|
| 24 |
+
"RandomApply",
|
| 25 |
+
"RandomBrightness",
|
| 26 |
+
"RandomContrast",
|
| 27 |
+
"RandomCrop",
|
| 28 |
+
"RandomExtent",
|
| 29 |
+
"RandomFlip",
|
| 30 |
+
"RandomSaturation",
|
| 31 |
+
"RandomLighting",
|
| 32 |
+
"RandomRotation",
|
| 33 |
+
"Resize",
|
| 34 |
+
"ResizeShortestEdge",
|
| 35 |
+
"TransformGen",
|
| 36 |
+
"apply_transform_gens",
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def check_dtype(img):
|
| 41 |
+
assert isinstance(img, np.ndarray), "[TransformGen] Needs an numpy array, but got a {}!".format(
|
| 42 |
+
type(img)
|
| 43 |
+
)
|
| 44 |
+
assert not isinstance(img.dtype, np.integer) or (
|
| 45 |
+
img.dtype == np.uint8
|
| 46 |
+
), "[TransformGen] Got image of type {}, use uint8 or floating points instead!".format(
|
| 47 |
+
img.dtype
|
| 48 |
+
)
|
| 49 |
+
assert img.ndim in [2, 3], img.ndim
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class TransformGen(metaclass=ABCMeta):
|
| 53 |
+
"""
|
| 54 |
+
TransformGen takes an image of type uint8 in range [0, 255], or
|
| 55 |
+
floating point in range [0, 1] or [0, 255] as input.
|
| 56 |
+
|
| 57 |
+
It creates a :class:`Transform` based on the given image, sometimes with randomness.
|
| 58 |
+
The transform can then be used to transform images
|
| 59 |
+
or other data (boxes, points, annotations, etc.) associated with it.
|
| 60 |
+
|
| 61 |
+
The assumption made in this class
|
| 62 |
+
is that the image itself is sufficient to instantiate a transform.
|
| 63 |
+
When this assumption is not true, you need to create the transforms by your own.
|
| 64 |
+
|
| 65 |
+
A list of `TransformGen` can be applied with :func:`apply_transform_gens`.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
def _init(self, params=None):
|
| 69 |
+
if params:
|
| 70 |
+
for k, v in params.items():
|
| 71 |
+
if k != "self" and not k.startswith("_"):
|
| 72 |
+
setattr(self, k, v)
|
| 73 |
+
|
| 74 |
+
@abstractmethod
|
| 75 |
+
def get_transform(self, img):
|
| 76 |
+
pass
|
| 77 |
+
|
| 78 |
+
def _rand_range(self, low=1.0, high=None, size=None):
|
| 79 |
+
"""
|
| 80 |
+
Uniform float random number between low and high.
|
| 81 |
+
"""
|
| 82 |
+
if high is None:
|
| 83 |
+
low, high = 0, low
|
| 84 |
+
if size is None:
|
| 85 |
+
size = []
|
| 86 |
+
return np.random.uniform(low, high, size)
|
| 87 |
+
|
| 88 |
+
def __repr__(self):
|
| 89 |
+
"""
|
| 90 |
+
Produce something like:
|
| 91 |
+
"MyTransformGen(field1={self.field1}, field2={self.field2})"
|
| 92 |
+
"""
|
| 93 |
+
try:
|
| 94 |
+
sig = inspect.signature(self.__init__)
|
| 95 |
+
classname = type(self).__name__
|
| 96 |
+
argstr = []
|
| 97 |
+
for name, param in sig.parameters.items():
|
| 98 |
+
assert (
|
| 99 |
+
param.kind != param.VAR_POSITIONAL and param.kind != param.VAR_KEYWORD
|
| 100 |
+
), "The default __repr__ doesn't support *args or **kwargs"
|
| 101 |
+
assert hasattr(self, name), (
|
| 102 |
+
"Attribute {} not found! "
|
| 103 |
+
"Default __repr__ only works if attributes match the constructor.".format(name)
|
| 104 |
+
)
|
| 105 |
+
attr = getattr(self, name)
|
| 106 |
+
default = param.default
|
| 107 |
+
if default is attr:
|
| 108 |
+
continue
|
| 109 |
+
argstr.append("{}={}".format(name, pprint.pformat(attr)))
|
| 110 |
+
return "{}({})".format(classname, ", ".join(argstr))
|
| 111 |
+
except AssertionError:
|
| 112 |
+
return super().__repr__()
|
| 113 |
+
|
| 114 |
+
__str__ = __repr__
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class RandomApply(TransformGen):
|
| 118 |
+
"""
|
| 119 |
+
Randomly apply the wrapper transformation with a given probability.
|
| 120 |
+
"""
|
| 121 |
+
|
| 122 |
+
def __init__(self, transform, prob=0.5):
|
| 123 |
+
"""
|
| 124 |
+
Args:
|
| 125 |
+
transform (Transform, TransformGen): the transform to be wrapped
|
| 126 |
+
by the `RandomApply`. The `transform` can either be a
|
| 127 |
+
`Transform` or `TransformGen` instance.
|
| 128 |
+
prob (float): probability between 0.0 and 1.0 that
|
| 129 |
+
the wrapper transformation is applied
|
| 130 |
+
"""
|
| 131 |
+
super().__init__()
|
| 132 |
+
assert isinstance(transform, (Transform, TransformGen)), (
|
| 133 |
+
f"The given transform must either be a Transform or TransformGen instance. "
|
| 134 |
+
f"Not {type(transform)}"
|
| 135 |
+
)
|
| 136 |
+
assert 0.0 <= prob <= 1.0, f"Probablity must be between 0.0 and 1.0 (given: {prob})"
|
| 137 |
+
self.prob = prob
|
| 138 |
+
self.transform = transform
|
| 139 |
+
|
| 140 |
+
def get_transform(self, img):
|
| 141 |
+
do = self._rand_range() < self.prob
|
| 142 |
+
if do:
|
| 143 |
+
if isinstance(self.transform, TransformGen):
|
| 144 |
+
return self.transform.get_transform(img)
|
| 145 |
+
else:
|
| 146 |
+
return self.transform
|
| 147 |
+
else:
|
| 148 |
+
return NoOpTransform()
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class RandomFlip(TransformGen):
|
| 152 |
+
"""
|
| 153 |
+
Flip the image horizontally or vertically with the given probability.
|
| 154 |
+
"""
|
| 155 |
+
|
| 156 |
+
def __init__(self, prob=0.5, *, horizontal=True, vertical=False):
|
| 157 |
+
"""
|
| 158 |
+
Args:
|
| 159 |
+
prob (float): probability of flip.
|
| 160 |
+
horizontal (boolean): whether to apply horizontal flipping
|
| 161 |
+
vertical (boolean): whether to apply vertical flipping
|
| 162 |
+
"""
|
| 163 |
+
super().__init__()
|
| 164 |
+
|
| 165 |
+
if horizontal and vertical:
|
| 166 |
+
raise ValueError("Cannot do both horiz and vert. Please use two Flip instead.")
|
| 167 |
+
if not horizontal and not vertical:
|
| 168 |
+
raise ValueError("At least one of horiz or vert has to be True!")
|
| 169 |
+
self._init(locals())
|
| 170 |
+
|
| 171 |
+
def get_transform(self, img):
|
| 172 |
+
h, w = img.shape[:2]
|
| 173 |
+
do = self._rand_range() < self.prob
|
| 174 |
+
if do:
|
| 175 |
+
if self.horizontal:
|
| 176 |
+
return HFlipTransform(w)
|
| 177 |
+
elif self.vertical:
|
| 178 |
+
return VFlipTransform(h)
|
| 179 |
+
else:
|
| 180 |
+
return NoOpTransform()
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class Resize(TransformGen):
|
| 184 |
+
""" Resize image to a target size"""
|
| 185 |
+
|
| 186 |
+
def __init__(self, shape, interp=Image.BILINEAR):
|
| 187 |
+
"""
|
| 188 |
+
Args:
|
| 189 |
+
shape: (h, w) tuple or a int
|
| 190 |
+
interp: PIL interpolation method
|
| 191 |
+
"""
|
| 192 |
+
if isinstance(shape, int):
|
| 193 |
+
shape = (shape, shape)
|
| 194 |
+
shape = tuple(shape)
|
| 195 |
+
self._init(locals())
|
| 196 |
+
|
| 197 |
+
def get_transform(self, img):
|
| 198 |
+
return ResizeTransform(
|
| 199 |
+
img.shape[0], img.shape[1], self.shape[0], self.shape[1], self.interp
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class ResizeShortestEdge(TransformGen):
|
| 204 |
+
"""
|
| 205 |
+
Scale the shorter edge to the given size, with a limit of `max_size` on the longer edge.
|
| 206 |
+
If `max_size` is reached, then downscale so that the longer edge does not exceed max_size.
|
| 207 |
+
"""
|
| 208 |
+
|
| 209 |
+
def __init__(
|
| 210 |
+
self, short_edge_length, max_size=sys.maxsize, sample_style="range", interp=Image.BILINEAR
|
| 211 |
+
):
|
| 212 |
+
"""
|
| 213 |
+
Args:
|
| 214 |
+
short_edge_length (list[int]): If ``sample_style=="range"``,
|
| 215 |
+
a [min, max] interval from which to sample the shortest edge length.
|
| 216 |
+
If ``sample_style=="choice"``, a list of shortest edge lengths to sample from.
|
| 217 |
+
max_size (int): maximum allowed longest edge length.
|
| 218 |
+
sample_style (str): either "range" or "choice".
|
| 219 |
+
"""
|
| 220 |
+
super().__init__()
|
| 221 |
+
assert sample_style in ["range", "choice"], sample_style
|
| 222 |
+
|
| 223 |
+
self.is_range = sample_style == "range"
|
| 224 |
+
if isinstance(short_edge_length, int):
|
| 225 |
+
short_edge_length = (short_edge_length, short_edge_length)
|
| 226 |
+
self._init(locals())
|
| 227 |
+
|
| 228 |
+
def get_transform(self, img):
|
| 229 |
+
h, w = img.shape[:2]
|
| 230 |
+
|
| 231 |
+
if self.is_range:
|
| 232 |
+
size = np.random.randint(self.short_edge_length[0], self.short_edge_length[1] + 1)
|
| 233 |
+
else:
|
| 234 |
+
size = np.random.choice(self.short_edge_length)
|
| 235 |
+
if size == 0:
|
| 236 |
+
return NoOpTransform()
|
| 237 |
+
|
| 238 |
+
scale = size * 1.0 / min(h, w)
|
| 239 |
+
if h < w:
|
| 240 |
+
newh, neww = size, scale * w
|
| 241 |
+
else:
|
| 242 |
+
newh, neww = scale * h, size
|
| 243 |
+
if max(newh, neww) > self.max_size:
|
| 244 |
+
scale = self.max_size * 1.0 / max(newh, neww)
|
| 245 |
+
newh = newh * scale
|
| 246 |
+
neww = neww * scale
|
| 247 |
+
neww = int(neww + 0.5)
|
| 248 |
+
newh = int(newh + 0.5)
|
| 249 |
+
return ResizeTransform(h, w, newh, neww, self.interp)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class RandomRotation(TransformGen):
|
| 253 |
+
"""
|
| 254 |
+
This method returns a copy of this image, rotated the given
|
| 255 |
+
number of degrees counter clockwise around the given center.
|
| 256 |
+
"""
|
| 257 |
+
|
| 258 |
+
def __init__(self, angle, expand=True, center=None, sample_style="range", interp=None):
|
| 259 |
+
"""
|
| 260 |
+
Args:
|
| 261 |
+
angle (list[float]): If ``sample_style=="range"``,
|
| 262 |
+
a [min, max] interval from which to sample the angle (in degrees).
|
| 263 |
+
If ``sample_style=="choice"``, a list of angles to sample from
|
| 264 |
+
expand (bool): choose if the image should be resized to fit the whole
|
| 265 |
+
rotated image (default), or simply cropped
|
| 266 |
+
center (list[[float, float]]): If ``sample_style=="range"``,
|
| 267 |
+
a [[minx, miny], [maxx, maxy]] relative interval from which to sample the center,
|
| 268 |
+
[0, 0] being the top left of the image and [1, 1] the bottom right.
|
| 269 |
+
If ``sample_style=="choice"``, a list of centers to sample from
|
| 270 |
+
Default: None, which means that the center of rotation is the center of the image
|
| 271 |
+
center has no effect if expand=True because it only affects shifting
|
| 272 |
+
"""
|
| 273 |
+
super().__init__()
|
| 274 |
+
assert sample_style in ["range", "choice"], sample_style
|
| 275 |
+
self.is_range = sample_style == "range"
|
| 276 |
+
if isinstance(angle, (float, int)):
|
| 277 |
+
angle = (angle, angle)
|
| 278 |
+
if center is not None and isinstance(center[0], (float, int)):
|
| 279 |
+
center = (center, center)
|
| 280 |
+
self._init(locals())
|
| 281 |
+
|
| 282 |
+
def get_transform(self, img):
|
| 283 |
+
h, w = img.shape[:2]
|
| 284 |
+
center = None
|
| 285 |
+
if self.is_range:
|
| 286 |
+
angle = np.random.uniform(self.angle[0], self.angle[1])
|
| 287 |
+
if self.center is not None:
|
| 288 |
+
center = (
|
| 289 |
+
np.random.uniform(self.center[0][0], self.center[1][0]),
|
| 290 |
+
np.random.uniform(self.center[0][1], self.center[1][1]),
|
| 291 |
+
)
|
| 292 |
+
else:
|
| 293 |
+
angle = np.random.choice(self.angle)
|
| 294 |
+
if self.center is not None:
|
| 295 |
+
center = np.random.choice(self.center)
|
| 296 |
+
|
| 297 |
+
if center is not None:
|
| 298 |
+
center = (w * center[0], h * center[1]) # Convert to absolute coordinates
|
| 299 |
+
|
| 300 |
+
return RotationTransform(h, w, angle, expand=self.expand, center=center, interp=self.interp)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
class RandomCrop(TransformGen):
|
| 304 |
+
"""
|
| 305 |
+
Randomly crop a subimage out of an image.
|
| 306 |
+
"""
|
| 307 |
+
|
| 308 |
+
def __init__(self, crop_type: str, crop_size):
|
| 309 |
+
"""
|
| 310 |
+
Args:
|
| 311 |
+
crop_type (str): one of "relative_range", "relative", "absolute".
|
| 312 |
+
See `config/defaults.py` for explanation.
|
| 313 |
+
crop_size (tuple[float]): the relative ratio or absolute pixels of
|
| 314 |
+
height and width
|
| 315 |
+
"""
|
| 316 |
+
super().__init__()
|
| 317 |
+
assert crop_type in ["relative_range", "relative", "absolute"]
|
| 318 |
+
self._init(locals())
|
| 319 |
+
|
| 320 |
+
def get_transform(self, img):
|
| 321 |
+
h, w = img.shape[:2]
|
| 322 |
+
croph, cropw = self.get_crop_size((h, w))
|
| 323 |
+
assert h >= croph and w >= cropw, "Shape computation in {} has bugs.".format(self)
|
| 324 |
+
h0 = np.random.randint(h - croph + 1)
|
| 325 |
+
w0 = np.random.randint(w - cropw + 1)
|
| 326 |
+
return CropTransform(w0, h0, cropw, croph)
|
| 327 |
+
|
| 328 |
+
def get_crop_size(self, image_size):
|
| 329 |
+
"""
|
| 330 |
+
Args:
|
| 331 |
+
image_size (tuple): height, width
|
| 332 |
+
|
| 333 |
+
Returns:
|
| 334 |
+
crop_size (tuple): height, width in absolute pixels
|
| 335 |
+
"""
|
| 336 |
+
h, w = image_size
|
| 337 |
+
if self.crop_type == "relative":
|
| 338 |
+
ch, cw = self.crop_size
|
| 339 |
+
return int(h * ch + 0.5), int(w * cw + 0.5)
|
| 340 |
+
elif self.crop_type == "relative_range":
|
| 341 |
+
crop_size = np.asarray(self.crop_size, dtype=np.float32)
|
| 342 |
+
ch, cw = crop_size + np.random.rand(2) * (1 - crop_size)
|
| 343 |
+
return int(h * ch + 0.5), int(w * cw + 0.5)
|
| 344 |
+
elif self.crop_type == "absolute":
|
| 345 |
+
return (min(self.crop_size[0], h), min(self.crop_size[1], w))
|
| 346 |
+
else:
|
| 347 |
+
NotImplementedError("Unknown crop type {}".format(self.crop_type))
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
class RandomExtent(TransformGen):
|
| 351 |
+
"""
|
| 352 |
+
Outputs an image by cropping a random "subrect" of the source image.
|
| 353 |
+
|
| 354 |
+
The subrect can be parameterized to include pixels outside the source image,
|
| 355 |
+
in which case they will be set to zeros (i.e. black). The size of the output
|
| 356 |
+
image will vary with the size of the random subrect.
|
| 357 |
+
"""
|
| 358 |
+
|
| 359 |
+
def __init__(self, scale_range, shift_range):
|
| 360 |
+
"""
|
| 361 |
+
Args:
|
| 362 |
+
output_size (h, w): Dimensions of output image
|
| 363 |
+
scale_range (l, h): Range of input-to-output size scaling factor
|
| 364 |
+
shift_range (x, y): Range of shifts of the cropped subrect. The rect
|
| 365 |
+
is shifted by [w / 2 * Uniform(-x, x), h / 2 * Uniform(-y, y)],
|
| 366 |
+
where (w, h) is the (width, height) of the input image. Set each
|
| 367 |
+
component to zero to crop at the image's center.
|
| 368 |
+
"""
|
| 369 |
+
super().__init__()
|
| 370 |
+
self._init(locals())
|
| 371 |
+
|
| 372 |
+
def get_transform(self, img):
|
| 373 |
+
img_h, img_w = img.shape[:2]
|
| 374 |
+
|
| 375 |
+
# Initialize src_rect to fit the input image.
|
| 376 |
+
src_rect = np.array([-0.5 * img_w, -0.5 * img_h, 0.5 * img_w, 0.5 * img_h])
|
| 377 |
+
|
| 378 |
+
# Apply a random scaling to the src_rect.
|
| 379 |
+
src_rect *= np.random.uniform(self.scale_range[0], self.scale_range[1])
|
| 380 |
+
|
| 381 |
+
# Apply a random shift to the coordinates origin.
|
| 382 |
+
src_rect[0::2] += self.shift_range[0] * img_w * (np.random.rand() - 0.5)
|
| 383 |
+
src_rect[1::2] += self.shift_range[1] * img_h * (np.random.rand() - 0.5)
|
| 384 |
+
|
| 385 |
+
# Map src_rect coordinates into image coordinates (center at corner).
|
| 386 |
+
src_rect[0::2] += 0.5 * img_w
|
| 387 |
+
src_rect[1::2] += 0.5 * img_h
|
| 388 |
+
|
| 389 |
+
return ExtentTransform(
|
| 390 |
+
src_rect=(src_rect[0], src_rect[1], src_rect[2], src_rect[3]),
|
| 391 |
+
output_size=(int(src_rect[3] - src_rect[1]), int(src_rect[2] - src_rect[0])),
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
class RandomContrast(TransformGen):
|
| 396 |
+
"""
|
| 397 |
+
Randomly transforms image contrast.
|
| 398 |
+
|
| 399 |
+
Contrast intensity is uniformly sampled in (intensity_min, intensity_max).
|
| 400 |
+
- intensity < 1 will reduce contrast
|
| 401 |
+
- intensity = 1 will preserve the input image
|
| 402 |
+
- intensity > 1 will increase contrast
|
| 403 |
+
|
| 404 |
+
See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
|
| 405 |
+
"""
|
| 406 |
+
|
| 407 |
+
def __init__(self, intensity_min, intensity_max):
|
| 408 |
+
"""
|
| 409 |
+
Args:
|
| 410 |
+
intensity_min (float): Minimum augmentation
|
| 411 |
+
intensity_max (float): Maximum augmentation
|
| 412 |
+
"""
|
| 413 |
+
super().__init__()
|
| 414 |
+
self._init(locals())
|
| 415 |
+
|
| 416 |
+
def get_transform(self, img):
|
| 417 |
+
w = np.random.uniform(self.intensity_min, self.intensity_max)
|
| 418 |
+
        return BlendTransform(src_image=img.mean(), src_weight=1 - w, dst_weight=w)


class RandomBrightness(TransformGen):
    """
    Randomly transforms image brightness.

    Brightness intensity is uniformly sampled in (intensity_min, intensity_max).
    - intensity < 1 will reduce brightness
    - intensity = 1 will preserve the input image
    - intensity > 1 will increase brightness

    See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
    """

    def __init__(self, intensity_min, intensity_max):
        """
        Args:
            intensity_min (float): Minimum augmentation
            intensity_max (float): Maximum augmentation
        """
        super().__init__()
        self._init(locals())

    def get_transform(self, img):
        w = np.random.uniform(self.intensity_min, self.intensity_max)
        return BlendTransform(src_image=0, src_weight=1 - w, dst_weight=w)


class RandomSaturation(TransformGen):
    """
    Randomly transforms image saturation.

    Saturation intensity is uniformly sampled in (intensity_min, intensity_max).
    - intensity < 1 will reduce saturation (make the image more grayscale)
    - intensity = 1 will preserve the input image
    - intensity > 1 will increase saturation

    See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
    """

    def __init__(self, intensity_min, intensity_max):
        """
        Args:
            intensity_min (float): Minimum augmentation (1 preserves input).
            intensity_max (float): Maximum augmentation (1 preserves input).
        """
        super().__init__()
        self._init(locals())

    def get_transform(self, img):
        assert img.shape[-1] == 3, "Saturation only works on RGB images"
        w = np.random.uniform(self.intensity_min, self.intensity_max)
        grayscale = img.dot([0.299, 0.587, 0.114])[:, :, np.newaxis]
        return BlendTransform(src_image=grayscale, src_weight=1 - w, dst_weight=w)


class RandomLighting(TransformGen):
    """
    Randomly transforms image color using fixed PCA over ImageNet.

    The degree of color jittering is randomly sampled via a normal distribution,
    with standard deviation given by the scale parameter.
    """

    def __init__(self, scale):
        """
        Args:
            scale (float): Standard deviation of principal component weighting.
        """
        super().__init__()
        self._init(locals())
        self.eigen_vecs = np.array(
            [[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140], [-0.5836, -0.6948, 0.4203]]
        )
        self.eigen_vals = np.array([0.2175, 0.0188, 0.0045])

    def get_transform(self, img):
        assert img.shape[-1] == 3, "RandomLighting only works on RGB images"
        weights = np.random.normal(scale=self.scale, size=3)
        return BlendTransform(
            src_image=self.eigen_vecs.dot(weights * self.eigen_vals), src_weight=1.0, dst_weight=1.0
        )


def apply_transform_gens(transform_gens, img):
    """
    Apply a list of :class:`TransformGen` or :class:`Transform` on the input image, and
    return the transformed image and a list of transforms.

    We cannot simply create and return all transforms without
    applying them to the image, because a subsequent transform may
    need the output of the previous one.

    Args:
        transform_gens (list): list of :class:`TransformGen` or :class:`Transform` instances to
            be applied.
        img (ndarray): uint8 or floating point images with 1 or 3 channels.

    Returns:
        ndarray: the transformed image
        TransformList: contains the transforms that were used.
    """
    for g in transform_gens:
        assert isinstance(g, (Transform, TransformGen)), g

    check_dtype(img)

    tfms = []
    for g in transform_gens:
        tfm = g.get_transform(img) if isinstance(g, TransformGen) else g
        assert isinstance(
            tfm, Transform
        ), "TransformGen {} must return an instance of Transform! Got {} instead".format(g, tfm)
        img = tfm.apply_image(img)
        tfms.append(tfm)
    return img, TransformList(tfms)
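For context, here is a minimal sketch of how these color-jitter `TransformGen` classes are typically composed. The import path, the dummy image, and the chosen intensity ranges are illustrative assumptions, not part of the diff above.

```python
# Sketch only: assumes this vendored detectron2 is importable and exposes these names
# from detectron2.data.transforms, as the upstream package does.
import numpy as np
from detectron2.data.transforms import RandomBrightness, RandomSaturation, apply_transform_gens

# Dummy uint8 RGB image standing in for a real input.
img = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)

# Each TransformGen samples a concrete Transform for this particular image;
# apply_transform_gens applies them in order and also returns the TransformList,
# so the exact same ops can later be replayed on boxes, masks, or keypoints.
gens = [RandomBrightness(0.8, 1.2), RandomSaturation(0.8, 1.2)]
aug_img, tfms = apply_transform_gens(gens, img)
print(aug_img.shape)
```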
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/evaluation/__init__.py
ADDED
@@ -0,0 +1,12 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .cityscapes_evaluation import CityscapesInstanceEvaluator, CityscapesSemSegEvaluator
from .coco_evaluation import COCOEvaluator
from .rotated_coco_evaluation import RotatedCOCOEvaluator
from .evaluator import DatasetEvaluator, DatasetEvaluators, inference_context, inference_on_dataset
from .lvis_evaluation import LVISEvaluator
from .panoptic_evaluation import COCOPanopticEvaluator
from .pascal_voc_evaluation import PascalVOCDetectionEvaluator
from .sem_seg_evaluation import SemSegEvaluator
from .testing import print_csv_format, verify_results

__all__ = [k for k in globals().keys() if not k.startswith("_")]
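In effect, this `__init__.py` makes all evaluator modules reachable from a single namespace. A typical downstream import (illustrative only, names taken from the exports above) looks like:

```python
from detectron2.evaluation import (
    COCOEvaluator,
    DatasetEvaluators,
    inference_on_dataset,
    print_csv_format,
)
```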
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/evaluation/cityscapes_evaluation.py
ADDED
@@ -0,0 +1,187 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import glob
import logging
import numpy as np
import os
import tempfile
from collections import OrderedDict
import torch
from fvcore.common.file_io import PathManager
from PIL import Image

from detectron2.data import MetadataCatalog
from detectron2.utils import comm

from .evaluator import DatasetEvaluator


class CityscapesEvaluator(DatasetEvaluator):
    """
    Base class for evaluation using cityscapes API.
    """

    def __init__(self, dataset_name):
        """
        Args:
            dataset_name (str): the name of the dataset.
                It must have the following metadata associated with it:
                "thing_classes", "gt_dir".
        """
        self._metadata = MetadataCatalog.get(dataset_name)
        self._cpu_device = torch.device("cpu")
        self._logger = logging.getLogger(__name__)

    def reset(self):
        self._working_dir = tempfile.TemporaryDirectory(prefix="cityscapes_eval_")
        self._temp_dir = self._working_dir.name
        # All workers will write to the same results directory
        # TODO this does not work in distributed training
        self._temp_dir = comm.all_gather(self._temp_dir)[0]
        if self._temp_dir != self._working_dir.name:
            self._working_dir.cleanup()
        self._logger.info(
            "Writing cityscapes results to temporary directory {} ...".format(self._temp_dir)
        )


class CityscapesInstanceEvaluator(CityscapesEvaluator):
    """
    Evaluate instance segmentation results using cityscapes API.

    Note:
        * It does not work in multi-machine distributed training.
        * It contains a synchronization, therefore has to be used on all ranks.
        * Only the main process runs evaluation.
    """

    def process(self, inputs, outputs):
        from cityscapesscripts.helpers.labels import name2label

        for input, output in zip(inputs, outputs):
            file_name = input["file_name"]
            basename = os.path.splitext(os.path.basename(file_name))[0]
            pred_txt = os.path.join(self._temp_dir, basename + "_pred.txt")

            output = output["instances"].to(self._cpu_device)
            num_instances = len(output)
            with open(pred_txt, "w") as fout:
                for i in range(num_instances):
                    pred_class = output.pred_classes[i]
                    classes = self._metadata.thing_classes[pred_class]
                    class_id = name2label[classes].id
                    score = output.scores[i]
                    mask = output.pred_masks[i].numpy().astype("uint8")
                    png_filename = os.path.join(
                        self._temp_dir, basename + "_{}_{}.png".format(i, classes)
                    )

                    Image.fromarray(mask * 255).save(png_filename)
                    fout.write("{} {} {}\n".format(os.path.basename(png_filename), class_id, score))

    def evaluate(self):
        """
        Returns:
            dict: has a key "segm", whose value is a dict of "AP" and "AP50".
        """
        comm.synchronize()
        if comm.get_rank() > 0:
            return
        import cityscapesscripts.evaluation.evalInstanceLevelSemanticLabeling as cityscapes_eval

        self._logger.info("Evaluating results under {} ...".format(self._temp_dir))

        # set some global states in cityscapes evaluation API, before evaluating
        cityscapes_eval.args.predictionPath = os.path.abspath(self._temp_dir)
        cityscapes_eval.args.predictionWalk = None
        cityscapes_eval.args.JSONOutput = False
        cityscapes_eval.args.colorized = False
        cityscapes_eval.args.gtInstancesFile = os.path.join(self._temp_dir, "gtInstances.json")

        # These lines are adopted from
        # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/evalInstanceLevelSemanticLabeling.py # noqa
        gt_dir = PathManager.get_local_path(self._metadata.gt_dir)
        groundTruthImgList = glob.glob(os.path.join(gt_dir, "*", "*_gtFine_instanceIds.png"))
        assert len(
            groundTruthImgList
        ), "Cannot find any ground truth images to use for evaluation. Searched for: {}".format(
            cityscapes_eval.args.groundTruthSearch
        )
        predictionImgList = []
        for gt in groundTruthImgList:
            predictionImgList.append(cityscapes_eval.getPrediction(gt, cityscapes_eval.args))
        results = cityscapes_eval.evaluateImgLists(
            predictionImgList, groundTruthImgList, cityscapes_eval.args
        )["averages"]

        ret = OrderedDict()
        ret["segm"] = {"AP": results["allAp"] * 100, "AP50": results["allAp50%"] * 100}
        self._working_dir.cleanup()
        return ret


class CityscapesSemSegEvaluator(CityscapesEvaluator):
    """
    Evaluate semantic segmentation results using cityscapes API.

    Note:
        * It does not work in multi-machine distributed training.
        * It contains a synchronization, therefore has to be used on all ranks.
        * Only the main process runs evaluation.
    """

    def process(self, inputs, outputs):
        from cityscapesscripts.helpers.labels import trainId2label

        for input, output in zip(inputs, outputs):
            file_name = input["file_name"]
            basename = os.path.splitext(os.path.basename(file_name))[0]
            pred_filename = os.path.join(self._temp_dir, basename + "_pred.png")

            output = output["sem_seg"].argmax(dim=0).to(self._cpu_device).numpy()
            pred = 255 * np.ones(output.shape, dtype=np.uint8)
            for train_id, label in trainId2label.items():
                if label.ignoreInEval:
                    continue
                pred[output == train_id] = label.id
            Image.fromarray(pred).save(pred_filename)

    def evaluate(self):
        comm.synchronize()
        if comm.get_rank() > 0:
            return
        # Load the Cityscapes eval script *after* setting the required env var,
        # since the script reads CITYSCAPES_DATASET into global variables at load time.
        import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as cityscapes_eval

        self._logger.info("Evaluating results under {} ...".format(self._temp_dir))

        # set some global states in cityscapes evaluation API, before evaluating
        cityscapes_eval.args.predictionPath = os.path.abspath(self._temp_dir)
        cityscapes_eval.args.predictionWalk = None
        cityscapes_eval.args.JSONOutput = False
        cityscapes_eval.args.colorized = False

        # These lines are adopted from
        # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/evalPixelLevelSemanticLabeling.py # noqa
        gt_dir = PathManager.get_local_path(self._metadata.gt_dir)
        groundTruthImgList = glob.glob(os.path.join(gt_dir, "*", "*_gtFine_labelIds.png"))
        assert len(
            groundTruthImgList
        ), "Cannot find any ground truth images to use for evaluation. Searched for: {}".format(
            cityscapes_eval.args.groundTruthSearch
        )
        predictionImgList = []
        for gt in groundTruthImgList:
            predictionImgList.append(cityscapes_eval.getPrediction(cityscapes_eval.args, gt))
        results = cityscapes_eval.evaluateImgLists(
            predictionImgList, groundTruthImgList, cityscapes_eval.args
        )
        ret = OrderedDict()
        ret["sem_seg"] = {
            "IoU": 100.0 * results["averageScoreClasses"],
            "iIoU": 100.0 * results["averageScoreInstClasses"],
            "IoU_sup": 100.0 * results["averageScoreCategories"],
            "iIoU_sup": 100.0 * results["averageScoreInstCategories"],
        }
        self._working_dir.cleanup()
        return ret
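A hedged usage sketch for the Cityscapes evaluators follows. The dataset name, `cfg`, and `model` are assumed to come from the usual detectron2 config/checkpoint setup; none of them are defined in this diff.

```python
from detectron2.data import build_detection_test_loader
from detectron2.evaluation import CityscapesInstanceEvaluator, inference_on_dataset

# Assumes `cfg` and `model` were built elsewhere and the Cityscapes fine val split is registered.
evaluator = CityscapesInstanceEvaluator("cityscapes_fine_instance_seg_val")
val_loader = build_detection_test_loader(cfg, "cityscapes_fine_instance_seg_val")
results = inference_on_dataset(model, val_loader, evaluator)
# On the main process this yields something like {"segm": {"AP": ..., "AP50": ...}}.
```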
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/evaluation/coco_evaluation.py
ADDED
@@ -0,0 +1,512 @@
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 2 |
+
import contextlib
|
| 3 |
+
import copy
|
| 4 |
+
import io
|
| 5 |
+
import itertools
|
| 6 |
+
import json
|
| 7 |
+
import logging
|
| 8 |
+
import numpy as np
|
| 9 |
+
import os
|
| 10 |
+
import pickle
|
| 11 |
+
from collections import OrderedDict
|
| 12 |
+
import pycocotools.mask as mask_util
|
| 13 |
+
import torch
|
| 14 |
+
from fvcore.common.file_io import PathManager
|
| 15 |
+
from pycocotools.coco import COCO
|
| 16 |
+
from pycocotools.cocoeval import COCOeval
|
| 17 |
+
from tabulate import tabulate
|
| 18 |
+
|
| 19 |
+
import detectron2.utils.comm as comm
|
| 20 |
+
from detectron2.data import MetadataCatalog
|
| 21 |
+
from detectron2.data.datasets.coco import convert_to_coco_json
|
| 22 |
+
from detectron2.structures import Boxes, BoxMode, pairwise_iou
|
| 23 |
+
from detectron2.utils.logger import create_small_table
|
| 24 |
+
|
| 25 |
+
from .evaluator import DatasetEvaluator
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class COCOEvaluator(DatasetEvaluator):
|
| 29 |
+
"""
|
| 30 |
+
Evaluate object proposal, instance detection/segmentation, keypoint detection
|
| 31 |
+
outputs using COCO's metrics and APIs.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(self, dataset_name, cfg, distributed, output_dir=None):
|
| 35 |
+
"""
|
| 36 |
+
Args:
|
| 37 |
+
dataset_name (str): name of the dataset to be evaluated.
|
| 38 |
+
It must have either the following corresponding metadata:
|
| 39 |
+
|
| 40 |
+
"json_file": the path to the COCO format annotation
|
| 41 |
+
|
| 42 |
+
Or it must be in detectron2's standard dataset format
|
| 43 |
+
so it can be converted to COCO format automatically.
|
| 44 |
+
cfg (CfgNode): config instance
|
| 45 |
+
distributed (True): if True, will collect results from all ranks and run evaluation
|
| 46 |
+
in the main process.
|
| 47 |
+
Otherwise, will evaluate the results in the current process.
|
| 48 |
+
output_dir (str): optional, an output directory to dump all
|
| 49 |
+
results predicted on the dataset. The dump contains two files:
|
| 50 |
+
|
| 51 |
+
1. "instance_predictions.pth" a file in torch serialization
|
| 52 |
+
format that contains all the raw original predictions.
|
| 53 |
+
2. "coco_instances_results.json" a json file in COCO's result
|
| 54 |
+
format.
|
| 55 |
+
"""
|
| 56 |
+
self._tasks = self._tasks_from_config(cfg)
|
| 57 |
+
self._distributed = distributed
|
| 58 |
+
self._output_dir = output_dir
|
| 59 |
+
|
| 60 |
+
self._cpu_device = torch.device("cpu")
|
| 61 |
+
self._logger = logging.getLogger(__name__)
|
| 62 |
+
|
| 63 |
+
self._metadata = MetadataCatalog.get(dataset_name)
|
| 64 |
+
if not hasattr(self._metadata, "json_file"):
|
| 65 |
+
self._logger.warning(
|
| 66 |
+
f"json_file was not found in MetaDataCatalog for '{dataset_name}'."
|
| 67 |
+
" Trying to convert it to COCO format ..."
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
cache_path = os.path.join(output_dir, f"{dataset_name}_coco_format.json")
|
| 71 |
+
self._metadata.json_file = cache_path
|
| 72 |
+
convert_to_coco_json(dataset_name, cache_path)
|
| 73 |
+
|
| 74 |
+
json_file = PathManager.get_local_path(self._metadata.json_file)
|
| 75 |
+
with contextlib.redirect_stdout(io.StringIO()):
|
| 76 |
+
self._coco_api = COCO(json_file)
|
| 77 |
+
|
| 78 |
+
self._kpt_oks_sigmas = cfg.TEST.KEYPOINT_OKS_SIGMAS
|
| 79 |
+
# Test set json files do not contain annotations (evaluation must be
|
| 80 |
+
# performed using the COCO evaluation server).
|
| 81 |
+
self._do_evaluation = "annotations" in self._coco_api.split_name
|
| 82 |
+
|
| 83 |
+
def reset(self):
|
| 84 |
+
self._predictions = []
|
| 85 |
+
|
| 86 |
+
def _tasks_from_config(self, cfg):
|
| 87 |
+
"""
|
| 88 |
+
Returns:
|
| 89 |
+
tuple[str]: tasks that can be evaluated under the given configuration.
|
| 90 |
+
"""
|
| 91 |
+
tasks = ("bbox",)
|
| 92 |
+
if cfg.MODEL.MASK_ON:
|
| 93 |
+
tasks = tasks + ("segm",)
|
| 94 |
+
if cfg.MODEL.KEYPOINT_ON:
|
| 95 |
+
tasks = tasks + ("keypoints",)
|
| 96 |
+
return tasks
|
| 97 |
+
|
| 98 |
+
def process(self, inputs, outputs):
|
| 99 |
+
"""
|
| 100 |
+
Args:
|
| 101 |
+
inputs: the inputs to a COCO model (e.g., GeneralizedRCNN).
|
| 102 |
+
It is a list of dict. Each dict corresponds to an image and
|
| 103 |
+
contains keys like "height", "width", "file_name", "image_id".
|
| 104 |
+
outputs: the outputs of a COCO model. It is a list of dicts with key
|
| 105 |
+
"instances" that contains :class:`Instances`.
|
| 106 |
+
"""
|
| 107 |
+
for input, output in zip(inputs, outputs):
|
| 108 |
+
prediction = {"image_id": input["image_id"]}
|
| 109 |
+
|
| 110 |
+
# TODO this is ugly
|
| 111 |
+
if "instances" in output:
|
| 112 |
+
instances = output["instances"].to(self._cpu_device)
|
| 113 |
+
prediction["instances"] = instances_to_coco_json(instances, input["image_id"])
|
| 114 |
+
if "proposals" in output:
|
| 115 |
+
prediction["proposals"] = output["proposals"].to(self._cpu_device)
|
| 116 |
+
self._predictions.append(prediction)
|
| 117 |
+
|
| 118 |
+
def evaluate(self):
|
| 119 |
+
if self._distributed:
|
| 120 |
+
comm.synchronize()
|
| 121 |
+
predictions = comm.gather(self._predictions, dst=0)
|
| 122 |
+
predictions = list(itertools.chain(*predictions))
|
| 123 |
+
|
| 124 |
+
if not comm.is_main_process():
|
| 125 |
+
return {}
|
| 126 |
+
else:
|
| 127 |
+
predictions = self._predictions
|
| 128 |
+
|
| 129 |
+
if len(predictions) == 0:
|
| 130 |
+
self._logger.warning("[COCOEvaluator] Did not receive valid predictions.")
|
| 131 |
+
return {}
|
| 132 |
+
|
| 133 |
+
if self._output_dir:
|
| 134 |
+
PathManager.mkdirs(self._output_dir)
|
| 135 |
+
file_path = os.path.join(self._output_dir, "instances_predictions.pth")
|
| 136 |
+
with PathManager.open(file_path, "wb") as f:
|
| 137 |
+
torch.save(predictions, f)
|
| 138 |
+
|
| 139 |
+
self._results = OrderedDict()
|
| 140 |
+
if "proposals" in predictions[0]:
|
| 141 |
+
self._eval_box_proposals(predictions)
|
| 142 |
+
if "instances" in predictions[0]:
|
| 143 |
+
self._eval_predictions(set(self._tasks), predictions)
|
| 144 |
+
# Copy so the caller can do whatever with results
|
| 145 |
+
return copy.deepcopy(self._results)
|
| 146 |
+
|
| 147 |
+
def _eval_predictions(self, tasks, predictions):
|
| 148 |
+
"""
|
| 149 |
+
Evaluate predictions on the given tasks.
|
| 150 |
+
Fill self._results with the metrics of the tasks.
|
| 151 |
+
"""
|
| 152 |
+
self._logger.info("Preparing results for COCO format ...")
|
| 153 |
+
coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))
|
| 154 |
+
|
| 155 |
+
# unmap the category ids for COCO
|
| 156 |
+
if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
|
| 157 |
+
reverse_id_mapping = {
|
| 158 |
+
v: k for k, v in self._metadata.thing_dataset_id_to_contiguous_id.items()
|
| 159 |
+
}
|
| 160 |
+
for result in coco_results:
|
| 161 |
+
category_id = result["category_id"]
|
| 162 |
+
assert (
|
| 163 |
+
category_id in reverse_id_mapping
|
| 164 |
+
), "A prediction has category_id={}, which is not available in the dataset.".format(
|
| 165 |
+
category_id
|
| 166 |
+
)
|
| 167 |
+
result["category_id"] = reverse_id_mapping[category_id]
|
| 168 |
+
|
| 169 |
+
if self._output_dir:
|
| 170 |
+
file_path = os.path.join(self._output_dir, "coco_instances_results.json")
|
| 171 |
+
self._logger.info("Saving results to {}".format(file_path))
|
| 172 |
+
with PathManager.open(file_path, "w") as f:
|
| 173 |
+
f.write(json.dumps(coco_results))
|
| 174 |
+
f.flush()
|
| 175 |
+
|
| 176 |
+
if not self._do_evaluation:
|
| 177 |
+
self._logger.info("Annotations are not available for evaluation.")
|
| 178 |
+
return
|
| 179 |
+
|
| 180 |
+
self._logger.info("Evaluating predictions ...")
|
| 181 |
+
for task in sorted(tasks):
|
| 182 |
+
coco_eval = (
|
| 183 |
+
_evaluate_predictions_on_coco(
|
| 184 |
+
self._coco_api, coco_results, task, kpt_oks_sigmas=self._kpt_oks_sigmas
|
| 185 |
+
)
|
| 186 |
+
if len(coco_results) > 0
|
| 187 |
+
else None # cocoapi does not handle empty results very well
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
res = self._derive_coco_results(
|
| 191 |
+
coco_eval, task, class_names=self._metadata.get("thing_classes")
|
| 192 |
+
)
|
| 193 |
+
self._results[task] = res
|
| 194 |
+
|
| 195 |
+
def _eval_box_proposals(self, predictions):
|
| 196 |
+
"""
|
| 197 |
+
Evaluate the box proposals in predictions.
|
| 198 |
+
Fill self._results with the metrics for "box_proposals" task.
|
| 199 |
+
"""
|
| 200 |
+
if self._output_dir:
|
| 201 |
+
# Saving generated box proposals to file.
|
| 202 |
+
# Predicted box_proposals are in XYXY_ABS mode.
|
| 203 |
+
bbox_mode = BoxMode.XYXY_ABS.value
|
| 204 |
+
ids, boxes, objectness_logits = [], [], []
|
| 205 |
+
for prediction in predictions:
|
| 206 |
+
ids.append(prediction["image_id"])
|
| 207 |
+
boxes.append(prediction["proposals"].proposal_boxes.tensor.numpy())
|
| 208 |
+
objectness_logits.append(prediction["proposals"].objectness_logits.numpy())
|
| 209 |
+
|
| 210 |
+
proposal_data = {
|
| 211 |
+
"boxes": boxes,
|
| 212 |
+
"objectness_logits": objectness_logits,
|
| 213 |
+
"ids": ids,
|
| 214 |
+
"bbox_mode": bbox_mode,
|
| 215 |
+
}
|
| 216 |
+
with PathManager.open(os.path.join(self._output_dir, "box_proposals.pkl"), "wb") as f:
|
| 217 |
+
pickle.dump(proposal_data, f)
|
| 218 |
+
|
| 219 |
+
if not self._do_evaluation:
|
| 220 |
+
self._logger.info("Annotations are not available for evaluation.")
|
| 221 |
+
return
|
| 222 |
+
|
| 223 |
+
self._logger.info("Evaluating bbox proposals ...")
|
| 224 |
+
res = {}
|
| 225 |
+
areas = {"all": "", "small": "s", "medium": "m", "large": "l"}
|
| 226 |
+
for limit in [100, 1000]:
|
| 227 |
+
for area, suffix in areas.items():
|
| 228 |
+
stats = _evaluate_box_proposals(predictions, self._coco_api, area=area, limit=limit)
|
| 229 |
+
key = "AR{}@{:d}".format(suffix, limit)
|
| 230 |
+
res[key] = float(stats["ar"].item() * 100)
|
| 231 |
+
self._logger.info("Proposal metrics: \n" + create_small_table(res))
|
| 232 |
+
self._results["box_proposals"] = res
|
| 233 |
+
|
| 234 |
+
def _derive_coco_results(self, coco_eval, iou_type, class_names=None):
|
| 235 |
+
"""
|
| 236 |
+
Derive the desired score numbers from summarized COCOeval.
|
| 237 |
+
|
| 238 |
+
Args:
|
| 239 |
+
coco_eval (None or COCOEval): None represents no predictions from model.
|
| 240 |
+
iou_type (str):
|
| 241 |
+
class_names (None or list[str]): if provided, will use it to predict
|
| 242 |
+
per-category AP.
|
| 243 |
+
|
| 244 |
+
Returns:
|
| 245 |
+
a dict of {metric name: score}
|
| 246 |
+
"""
|
| 247 |
+
|
| 248 |
+
metrics = {
|
| 249 |
+
"bbox": ["AP", "AP50", "AP75", "APs", "APm", "APl"],
|
| 250 |
+
"segm": ["AP", "AP50", "AP75", "APs", "APm", "APl"],
|
| 251 |
+
"keypoints": ["AP", "AP50", "AP75", "APm", "APl"],
|
| 252 |
+
}[iou_type]
|
| 253 |
+
|
| 254 |
+
if coco_eval is None:
|
| 255 |
+
self._logger.warn("No predictions from the model!")
|
| 256 |
+
return {metric: float("nan") for metric in metrics}
|
| 257 |
+
|
| 258 |
+
# the standard metrics
|
| 259 |
+
results = {
|
| 260 |
+
metric: float(coco_eval.stats[idx] * 100 if coco_eval.stats[idx] >= 0 else "nan")
|
| 261 |
+
for idx, metric in enumerate(metrics)
|
| 262 |
+
}
|
| 263 |
+
self._logger.info(
|
| 264 |
+
"Evaluation results for {}: \n".format(iou_type) + create_small_table(results)
|
| 265 |
+
)
|
| 266 |
+
if not np.isfinite(sum(results.values())):
|
| 267 |
+
self._logger.info("Note that some metrics cannot be computed.")
|
| 268 |
+
|
| 269 |
+
if class_names is None or len(class_names) <= 1:
|
| 270 |
+
return results
|
| 271 |
+
# Compute per-category AP
|
| 272 |
+
# from https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L222-L252 # noqa
|
| 273 |
+
precisions = coco_eval.eval["precision"]
|
| 274 |
+
# precision has dims (iou, recall, cls, area range, max dets)
|
| 275 |
+
assert len(class_names) == precisions.shape[2]
|
| 276 |
+
|
| 277 |
+
results_per_category = []
|
| 278 |
+
for idx, name in enumerate(class_names):
|
| 279 |
+
# area range index 0: all area ranges
|
| 280 |
+
# max dets index -1: typically 100 per image
|
| 281 |
+
precision = precisions[:, :, idx, 0, -1]
|
| 282 |
+
precision = precision[precision > -1]
|
| 283 |
+
ap = np.mean(precision) if precision.size else float("nan")
|
| 284 |
+
results_per_category.append(("{}".format(name), float(ap * 100)))
|
| 285 |
+
|
| 286 |
+
# tabulate it
|
| 287 |
+
N_COLS = min(6, len(results_per_category) * 2)
|
| 288 |
+
results_flatten = list(itertools.chain(*results_per_category))
|
| 289 |
+
results_2d = itertools.zip_longest(*[results_flatten[i::N_COLS] for i in range(N_COLS)])
|
| 290 |
+
table = tabulate(
|
| 291 |
+
results_2d,
|
| 292 |
+
tablefmt="pipe",
|
| 293 |
+
floatfmt=".3f",
|
| 294 |
+
headers=["category", "AP"] * (N_COLS // 2),
|
| 295 |
+
numalign="left",
|
| 296 |
+
)
|
| 297 |
+
self._logger.info("Per-category {} AP: \n".format(iou_type) + table)
|
| 298 |
+
|
| 299 |
+
results.update({"AP-" + name: ap for name, ap in results_per_category})
|
| 300 |
+
return results
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def instances_to_coco_json(instances, img_id):
|
| 304 |
+
"""
|
| 305 |
+
Dump an "Instances" object to a COCO-format json that's used for evaluation.
|
| 306 |
+
|
| 307 |
+
Args:
|
| 308 |
+
instances (Instances):
|
| 309 |
+
img_id (int): the image id
|
| 310 |
+
|
| 311 |
+
Returns:
|
| 312 |
+
list[dict]: list of json annotations in COCO format.
|
| 313 |
+
"""
|
| 314 |
+
num_instance = len(instances)
|
| 315 |
+
if num_instance == 0:
|
| 316 |
+
return []
|
| 317 |
+
|
| 318 |
+
boxes = instances.pred_boxes.tensor.numpy()
|
| 319 |
+
boxes = BoxMode.convert(boxes, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
|
| 320 |
+
boxes = boxes.tolist()
|
| 321 |
+
scores = instances.scores.tolist()
|
| 322 |
+
classes = instances.pred_classes.tolist()
|
| 323 |
+
|
| 324 |
+
has_mask = instances.has("pred_masks")
|
| 325 |
+
if has_mask:
|
| 326 |
+
# use RLE to encode the masks, because they are too large and takes memory
|
| 327 |
+
# since this evaluator stores outputs of the entire dataset
|
| 328 |
+
rles = [
|
| 329 |
+
mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
|
| 330 |
+
for mask in instances.pred_masks
|
| 331 |
+
]
|
| 332 |
+
for rle in rles:
|
| 333 |
+
# "counts" is an array encoded by mask_util as a byte-stream. Python3's
|
| 334 |
+
# json writer which always produces strings cannot serialize a bytestream
|
| 335 |
+
# unless you decode it. Thankfully, utf-8 works out (which is also what
|
| 336 |
+
# the pycocotools/_mask.pyx does).
|
| 337 |
+
rle["counts"] = rle["counts"].decode("utf-8")
|
| 338 |
+
|
| 339 |
+
has_keypoints = instances.has("pred_keypoints")
|
| 340 |
+
if has_keypoints:
|
| 341 |
+
keypoints = instances.pred_keypoints
|
| 342 |
+
|
| 343 |
+
results = []
|
| 344 |
+
for k in range(num_instance):
|
| 345 |
+
result = {
|
| 346 |
+
"image_id": img_id,
|
| 347 |
+
"category_id": classes[k],
|
| 348 |
+
"bbox": boxes[k],
|
| 349 |
+
"score": scores[k],
|
| 350 |
+
}
|
| 351 |
+
if has_mask:
|
| 352 |
+
result["segmentation"] = rles[k]
|
| 353 |
+
if has_keypoints:
|
| 354 |
+
# In COCO annotations,
|
| 355 |
+
# keypoints coordinates are pixel indices.
|
| 356 |
+
# However our predictions are floating point coordinates.
|
| 357 |
+
# Therefore we subtract 0.5 to be consistent with the annotation format.
|
| 358 |
+
# This is the inverse of data loading logic in `data/coco.py`.
|
| 359 |
+
keypoints[k][:, :2] -= 0.5
|
| 360 |
+
result["keypoints"] = keypoints[k].flatten().tolist()
|
| 361 |
+
results.append(result)
|
| 362 |
+
return results
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
# inspired from Detectron:
|
| 366 |
+
# https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L255 # noqa
|
| 367 |
+
def _evaluate_box_proposals(dataset_predictions, coco_api, thresholds=None, area="all", limit=None):
|
| 368 |
+
"""
|
| 369 |
+
Evaluate detection proposal recall metrics. This function is a much
|
| 370 |
+
faster alternative to the official COCO API recall evaluation code. However,
|
| 371 |
+
it produces slightly different results.
|
| 372 |
+
"""
|
| 373 |
+
# Record max overlap value for each gt box
|
| 374 |
+
# Return vector of overlap values
|
| 375 |
+
areas = {
|
| 376 |
+
"all": 0,
|
| 377 |
+
"small": 1,
|
| 378 |
+
"medium": 2,
|
| 379 |
+
"large": 3,
|
| 380 |
+
"96-128": 4,
|
| 381 |
+
"128-256": 5,
|
| 382 |
+
"256-512": 6,
|
| 383 |
+
"512-inf": 7,
|
| 384 |
+
}
|
| 385 |
+
area_ranges = [
|
| 386 |
+
[0 ** 2, 1e5 ** 2], # all
|
| 387 |
+
[0 ** 2, 32 ** 2], # small
|
| 388 |
+
[32 ** 2, 96 ** 2], # medium
|
| 389 |
+
[96 ** 2, 1e5 ** 2], # large
|
| 390 |
+
[96 ** 2, 128 ** 2], # 96-128
|
| 391 |
+
[128 ** 2, 256 ** 2], # 128-256
|
| 392 |
+
[256 ** 2, 512 ** 2], # 256-512
|
| 393 |
+
[512 ** 2, 1e5 ** 2],
|
| 394 |
+
] # 512-inf
|
| 395 |
+
assert area in areas, "Unknown area range: {}".format(area)
|
| 396 |
+
area_range = area_ranges[areas[area]]
|
| 397 |
+
gt_overlaps = []
|
| 398 |
+
num_pos = 0
|
| 399 |
+
|
| 400 |
+
for prediction_dict in dataset_predictions:
|
| 401 |
+
predictions = prediction_dict["proposals"]
|
| 402 |
+
|
| 403 |
+
# sort predictions in descending order
|
| 404 |
+
# TODO maybe remove this and make it explicit in the documentation
|
| 405 |
+
inds = predictions.objectness_logits.sort(descending=True)[1]
|
| 406 |
+
predictions = predictions[inds]
|
| 407 |
+
|
| 408 |
+
ann_ids = coco_api.getAnnIds(imgIds=prediction_dict["image_id"])
|
| 409 |
+
anno = coco_api.loadAnns(ann_ids)
|
| 410 |
+
gt_boxes = [
|
| 411 |
+
BoxMode.convert(obj["bbox"], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
|
| 412 |
+
for obj in anno
|
| 413 |
+
if obj["iscrowd"] == 0
|
| 414 |
+
]
|
| 415 |
+
gt_boxes = torch.as_tensor(gt_boxes).reshape(-1, 4) # guard against no boxes
|
| 416 |
+
gt_boxes = Boxes(gt_boxes)
|
| 417 |
+
gt_areas = torch.as_tensor([obj["area"] for obj in anno if obj["iscrowd"] == 0])
|
| 418 |
+
|
| 419 |
+
if len(gt_boxes) == 0 or len(predictions) == 0:
|
| 420 |
+
continue
|
| 421 |
+
|
| 422 |
+
valid_gt_inds = (gt_areas >= area_range[0]) & (gt_areas <= area_range[1])
|
| 423 |
+
gt_boxes = gt_boxes[valid_gt_inds]
|
| 424 |
+
|
| 425 |
+
num_pos += len(gt_boxes)
|
| 426 |
+
|
| 427 |
+
if len(gt_boxes) == 0:
|
| 428 |
+
continue
|
| 429 |
+
|
| 430 |
+
if limit is not None and len(predictions) > limit:
|
| 431 |
+
predictions = predictions[:limit]
|
| 432 |
+
|
| 433 |
+
overlaps = pairwise_iou(predictions.proposal_boxes, gt_boxes)
|
| 434 |
+
|
| 435 |
+
_gt_overlaps = torch.zeros(len(gt_boxes))
|
| 436 |
+
for j in range(min(len(predictions), len(gt_boxes))):
|
| 437 |
+
# find which proposal box maximally covers each gt box
|
| 438 |
+
# and get the iou amount of coverage for each gt box
|
| 439 |
+
max_overlaps, argmax_overlaps = overlaps.max(dim=0)
|
| 440 |
+
|
| 441 |
+
# find which gt box is 'best' covered (i.e. 'best' = most iou)
|
| 442 |
+
gt_ovr, gt_ind = max_overlaps.max(dim=0)
|
| 443 |
+
assert gt_ovr >= 0
|
| 444 |
+
# find the proposal box that covers the best covered gt box
|
| 445 |
+
box_ind = argmax_overlaps[gt_ind]
|
| 446 |
+
# record the iou coverage of this gt box
|
| 447 |
+
_gt_overlaps[j] = overlaps[box_ind, gt_ind]
|
| 448 |
+
assert _gt_overlaps[j] == gt_ovr
|
| 449 |
+
# mark the proposal box and the gt box as used
|
| 450 |
+
overlaps[box_ind, :] = -1
|
| 451 |
+
overlaps[:, gt_ind] = -1
|
| 452 |
+
|
| 453 |
+
# append recorded iou coverage level
|
| 454 |
+
gt_overlaps.append(_gt_overlaps)
|
| 455 |
+
gt_overlaps = (
|
| 456 |
+
torch.cat(gt_overlaps, dim=0) if len(gt_overlaps) else torch.zeros(0, dtype=torch.float32)
|
| 457 |
+
)
|
| 458 |
+
gt_overlaps, _ = torch.sort(gt_overlaps)
|
| 459 |
+
|
| 460 |
+
if thresholds is None:
|
| 461 |
+
step = 0.05
|
| 462 |
+
thresholds = torch.arange(0.5, 0.95 + 1e-5, step, dtype=torch.float32)
|
| 463 |
+
recalls = torch.zeros_like(thresholds)
|
| 464 |
+
# compute recall for each iou threshold
|
| 465 |
+
for i, t in enumerate(thresholds):
|
| 466 |
+
recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos)
|
| 467 |
+
# ar = 2 * np.trapz(recalls, thresholds)
|
| 468 |
+
ar = recalls.mean()
|
| 469 |
+
return {
|
| 470 |
+
"ar": ar,
|
| 471 |
+
"recalls": recalls,
|
| 472 |
+
"thresholds": thresholds,
|
| 473 |
+
"gt_overlaps": gt_overlaps,
|
| 474 |
+
"num_pos": num_pos,
|
| 475 |
+
}
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
def _evaluate_predictions_on_coco(coco_gt, coco_results, iou_type, kpt_oks_sigmas=None):
|
| 479 |
+
"""
|
| 480 |
+
Evaluate the coco results using COCOEval API.
|
| 481 |
+
"""
|
| 482 |
+
assert len(coco_results) > 0
|
| 483 |
+
|
| 484 |
+
if iou_type == "segm":
|
| 485 |
+
coco_results = copy.deepcopy(coco_results)
|
| 486 |
+
# When evaluating mask AP, if the results contain bbox, cocoapi will
|
| 487 |
+
# use the box area as the area of the instance, instead of the mask area.
|
| 488 |
+
# This leads to a different definition of small/medium/large.
|
| 489 |
+
# We remove the bbox field to let mask AP use mask area.
|
| 490 |
+
for c in coco_results:
|
| 491 |
+
c.pop("bbox", None)
|
| 492 |
+
|
| 493 |
+
coco_dt = coco_gt.loadRes(coco_results)
|
| 494 |
+
coco_eval = COCOeval(coco_gt, coco_dt, iou_type)
|
| 495 |
+
# Use the COCO default keypoint OKS sigmas unless overrides are specified
|
| 496 |
+
if kpt_oks_sigmas:
|
| 497 |
+
coco_eval.params.kpt_oks_sigmas = np.array(kpt_oks_sigmas)
|
| 498 |
+
|
| 499 |
+
if iou_type == "keypoints":
|
| 500 |
+
num_keypoints = len(coco_results[0]["keypoints"]) // 3
|
| 501 |
+
assert len(coco_eval.params.kpt_oks_sigmas) == num_keypoints, (
|
| 502 |
+
"[COCOEvaluator] The length of cfg.TEST.KEYPOINT_OKS_SIGMAS (default: 17) "
|
| 503 |
+
"must be equal to the number of keypoints. However the prediction has {} "
|
| 504 |
+
"keypoints! For more information please refer to "
|
| 505 |
+
"http://cocodataset.org/#keypoints-eval.".format(num_keypoints)
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
coco_eval.evaluate()
|
| 509 |
+
coco_eval.accumulate()
|
| 510 |
+
coco_eval.summarize()
|
| 511 |
+
|
| 512 |
+
return coco_eval
|
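A short, hedged sketch of how `COCOEvaluator` is typically driven; its constructor signature `(dataset_name, cfg, distributed, output_dir=None)` comes from the file above, while `cfg`, `model`, and the split name are illustrative assumptions.

```python
from detectron2.data import build_detection_test_loader
from detectron2.evaluation import COCOEvaluator, inference_on_dataset

# Assumes `cfg` and `model` come from the usual detectron2 setup and that
# "coco_2017_val" (or any registered COCO-format split) is available.
evaluator = COCOEvaluator("coco_2017_val", cfg, distributed=False, output_dir="./coco_eval")
val_loader = build_detection_test_loader(cfg, "coco_2017_val")
metrics = inference_on_dataset(model, val_loader, evaluator)
print(metrics.get("bbox", {}))  # per-task dicts such as {"AP": ..., "AP50": ...}
```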
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/evaluation/evaluator.py
ADDED
@@ -0,0 +1,196 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import datetime
import logging
import time
from collections import OrderedDict
from contextlib import contextmanager
import torch

from detectron2.utils.comm import get_world_size, is_main_process
from detectron2.utils.logger import log_every_n_seconds


class DatasetEvaluator:
    """
    Base class for a dataset evaluator.

    The function :func:`inference_on_dataset` runs the model over
    all samples in the dataset, and have a DatasetEvaluator to process the inputs/outputs.

    This class will accumulate information of the inputs/outputs (by :meth:`process`),
    and produce evaluation results in the end (by :meth:`evaluate`).
    """

    def reset(self):
        """
        Preparation for a new round of evaluation.
        Should be called before starting a round of evaluation.
        """
        pass

    def process(self, inputs, outputs):
        """
        Process the pair of inputs and outputs.
        If they contain batches, the pairs can be consumed one-by-one using `zip`:

        .. code-block:: python

            for input_, output in zip(inputs, outputs):
                # do evaluation on single input/output pair
                ...

        Args:
            inputs (list): the inputs that's used to call the model.
            outputs (list): the return value of `model(inputs)`
        """
        pass

    def evaluate(self):
        """
        Evaluate/summarize the performance, after processing all input/output pairs.

        Returns:
            dict:
                A new evaluator class can return a dict of arbitrary format
                as long as the user can process the results.
                In our train_net.py, we expect the following format:

                * key: the name of the task (e.g., bbox)
                * value: a dict of {metric name: score}, e.g.: {"AP50": 80}
        """
        pass


class DatasetEvaluators(DatasetEvaluator):
    """
    Wrapper class to combine multiple :class:`DatasetEvaluator` instances.

    This class dispatches every evaluation call to
    all of its :class:`DatasetEvaluator`.
    """

    def __init__(self, evaluators):
        """
        Args:
            evaluators (list): the evaluators to combine.
        """
        super().__init__()
        self._evaluators = evaluators

    def reset(self):
        for evaluator in self._evaluators:
            evaluator.reset()

    def process(self, inputs, outputs):
        for evaluator in self._evaluators:
            evaluator.process(inputs, outputs)

    def evaluate(self):
        results = OrderedDict()
        for evaluator in self._evaluators:
            result = evaluator.evaluate()
            if is_main_process() and result is not None:
                for k, v in result.items():
                    assert (
                        k not in results
                    ), "Different evaluators produce results with the same key {}".format(k)
                    results[k] = v
        return results


def inference_on_dataset(model, data_loader, evaluator):
    """
    Run model on the data_loader and evaluate the metrics with evaluator.
    Also benchmark the inference speed of `model.forward` accurately.
    The model will be used in eval mode.

    Args:
        model (nn.Module): a module which accepts an object from
            `data_loader` and returns some outputs. It will be temporarily set to `eval` mode.

            If you wish to evaluate a model in `training` mode instead, you can
            wrap the given model and override its behavior of `.eval()` and `.train()`.
        data_loader: an iterable object with a length.
            The elements it generates will be the inputs to the model.
        evaluator (DatasetEvaluator): the evaluator to run. Use `None` if you only want
            to benchmark, but don't want to do any evaluation.

    Returns:
        The return value of `evaluator.evaluate()`
    """
    num_devices = get_world_size()
    logger = logging.getLogger(__name__)
    logger.info("Start inference on {} images".format(len(data_loader)))

    total = len(data_loader)  # inference data loader must have a fixed length
    if evaluator is None:
        # create a no-op evaluator
        evaluator = DatasetEvaluators([])
    evaluator.reset()

    num_warmup = min(5, total - 1)
    start_time = time.perf_counter()
    total_compute_time = 0
    with inference_context(model), torch.no_grad():
        for idx, inputs in enumerate(data_loader):
            if idx == num_warmup:
                start_time = time.perf_counter()
                total_compute_time = 0

            start_compute_time = time.perf_counter()
            outputs = model(inputs)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            total_compute_time += time.perf_counter() - start_compute_time
            evaluator.process(inputs, outputs)

            iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
            seconds_per_img = total_compute_time / iters_after_start
            if idx >= num_warmup * 2 or seconds_per_img > 5:
                total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
                eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
                log_every_n_seconds(
                    logging.INFO,
                    "Inference done {}/{}. {:.4f} s / demo. ETA={}".format(
                        idx + 1, total, seconds_per_img, str(eta)
                    ),
                    n=5,
                )

    # Measure the time only for this worker (before the synchronization barrier)
    total_time = time.perf_counter() - start_time
    total_time_str = str(datetime.timedelta(seconds=total_time))
    # NOTE this format is parsed by grep
    logger.info(
        "Total inference time: {} ({:.6f} s / demo per device, on {} devices)".format(
            total_time_str, total_time / (total - num_warmup), num_devices
        )
    )
    total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time)))
    logger.info(
        "Total inference pure compute time: {} ({:.6f} s / demo per device, on {} devices)".format(
            total_compute_time_str, total_compute_time / (total - num_warmup), num_devices
        )
    )

    results = evaluator.evaluate()
    # An evaluator may return None when not in main process.
    # Replace it by an empty dict instead to make it easier for downstream code to handle
    if results is None:
        results = {}
    return results


@contextmanager
def inference_context(model):
    """
    A context where the model is temporarily changed to eval mode,
    and restored to previous mode afterwards.

    Args:
        model: a torch Module
    """
    training_mode = model.training
    model.eval()
    yield
    model.train(training_mode)
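Because `DatasetEvaluator` is just a three-method interface (`reset`, `process`, `evaluate`), a custom metric can be run alongside the built-in evaluators. The snippet below is an illustrative sketch; `model` and `val_loader` are assumed to exist as in the earlier examples and the `InstanceCounter` class is hypothetical.

```python
from detectron2.evaluation import DatasetEvaluator, DatasetEvaluators, inference_on_dataset

class InstanceCounter(DatasetEvaluator):
    """Toy evaluator: counts predicted instances over the whole dataset."""

    def reset(self):
        self.count = 0

    def process(self, inputs, outputs):
        for output in outputs:
            if "instances" in output:
                self.count += len(output["instances"])

    def evaluate(self):
        return {"counting": {"num_instances": self.count}}

# DatasetEvaluators merges each evaluator's result dict; keys must be unique across evaluators.
results = inference_on_dataset(model, val_loader, DatasetEvaluators([InstanceCounter()]))
```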
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/evaluation/lvis_evaluation.py
ADDED
@@ -0,0 +1,350 @@
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 2 |
+
import copy
|
| 3 |
+
import itertools
|
| 4 |
+
import json
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import pickle
|
| 8 |
+
from collections import OrderedDict
|
| 9 |
+
import torch
|
| 10 |
+
from fvcore.common.file_io import PathManager
|
| 11 |
+
|
| 12 |
+
import detectron2.utils.comm as comm
|
| 13 |
+
from detectron2.data import MetadataCatalog
|
| 14 |
+
from detectron2.structures import Boxes, BoxMode, pairwise_iou
|
| 15 |
+
from detectron2.utils.logger import create_small_table
|
| 16 |
+
|
| 17 |
+
from .coco_evaluation import instances_to_coco_json
|
| 18 |
+
from .evaluator import DatasetEvaluator
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class LVISEvaluator(DatasetEvaluator):
|
| 22 |
+
"""
|
| 23 |
+
Evaluate object proposal and instance detection/segmentation outputs using
|
| 24 |
+
LVIS's metrics and evaluation API.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(self, dataset_name, cfg, distributed, output_dir=None):
|
| 28 |
+
"""
|
| 29 |
+
Args:
|
| 30 |
+
dataset_name (str): name of the dataset to be evaluated.
|
| 31 |
+
It must have the following corresponding metadata:
|
| 32 |
+
"json_file": the path to the LVIS format annotation
|
| 33 |
+
cfg (CfgNode): config instance
|
| 34 |
+
distributed (True): if True, will collect results from all ranks for evaluation.
|
| 35 |
+
Otherwise, will evaluate the results in the current process.
|
| 36 |
+
output_dir (str): optional, an output directory to dump results.
|
| 37 |
+
"""
|
| 38 |
+
from lvis import LVIS
|
| 39 |
+
|
| 40 |
+
self._tasks = self._tasks_from_config(cfg)
|
| 41 |
+
self._distributed = distributed
|
| 42 |
+
self._output_dir = output_dir
|
| 43 |
+
|
| 44 |
+
self._cpu_device = torch.device("cpu")
|
| 45 |
+
self._logger = logging.getLogger(__name__)
|
| 46 |
+
|
| 47 |
+
self._metadata = MetadataCatalog.get(dataset_name)
|
| 48 |
+
json_file = PathManager.get_local_path(self._metadata.json_file)
|
| 49 |
+
self._lvis_api = LVIS(json_file)
|
| 50 |
+
# Test set json files do not contain annotations (evaluation must be
|
| 51 |
+
# performed using the LVIS evaluation server).
|
| 52 |
+
self._do_evaluation = len(self._lvis_api.get_ann_ids()) > 0
|
| 53 |
+
|
| 54 |
+
def reset(self):
|
| 55 |
+
self._predictions = []
|
| 56 |
+
|
| 57 |
+
def _tasks_from_config(self, cfg):
|
| 58 |
+
"""
|
| 59 |
+
Returns:
|
| 60 |
+
tuple[str]: tasks that can be evaluated under the given configuration.
|
| 61 |
+
"""
|
| 62 |
+
tasks = ("bbox",)
|
| 63 |
+
if cfg.MODEL.MASK_ON:
|
| 64 |
+
tasks = tasks + ("segm",)
|
| 65 |
+
return tasks
|
| 66 |
+
|
| 67 |
+
def process(self, inputs, outputs):
|
| 68 |
+
"""
|
| 69 |
+
Args:
|
| 70 |
+
inputs: the inputs to a LVIS model (e.g., GeneralizedRCNN).
|
| 71 |
+
It is a list of dict. Each dict corresponds to an image and
|
| 72 |
+
contains keys like "height", "width", "file_name", "image_id".
|
| 73 |
+
outputs: the outputs of a LVIS model. It is a list of dicts with key
|
| 74 |
+
"instances" that contains :class:`Instances`.
|
| 75 |
+
"""
|
| 76 |
+
for input, output in zip(inputs, outputs):
|
| 77 |
+
prediction = {"image_id": input["image_id"]}
|
| 78 |
+
|
| 79 |
+
if "instances" in output:
|
| 80 |
+
instances = output["instances"].to(self._cpu_device)
|
| 81 |
+
prediction["instances"] = instances_to_coco_json(instances, input["image_id"])
|
| 82 |
+
if "proposals" in output:
|
| 83 |
+
prediction["proposals"] = output["proposals"].to(self._cpu_device)
|
| 84 |
+
self._predictions.append(prediction)
|
| 85 |
+
|
| 86 |
+
def evaluate(self):
|
| 87 |
+
if self._distributed:
|
| 88 |
+
comm.synchronize()
|
| 89 |
+
predictions = comm.gather(self._predictions, dst=0)
|
| 90 |
+
predictions = list(itertools.chain(*predictions))
|
| 91 |
+
|
| 92 |
+
if not comm.is_main_process():
|
| 93 |
+
return
|
| 94 |
+
else:
|
| 95 |
+
predictions = self._predictions
|
| 96 |
+
|
| 97 |
+
if len(predictions) == 0:
|
| 98 |
+
self._logger.warning("[LVISEvaluator] Did not receive valid predictions.")
|
| 99 |
+
return {}
|
| 100 |
+
|
| 101 |
+
if self._output_dir:
|
| 102 |
+
PathManager.mkdirs(self._output_dir)
|
            file_path = os.path.join(self._output_dir, "instances_predictions.pth")
            with PathManager.open(file_path, "wb") as f:
                torch.save(predictions, f)

        self._results = OrderedDict()
        if "proposals" in predictions[0]:
            self._eval_box_proposals(predictions)
        if "instances" in predictions[0]:
            self._eval_predictions(set(self._tasks), predictions)
        # Copy so the caller can do whatever with results
        return copy.deepcopy(self._results)

    def _eval_predictions(self, tasks, predictions):
        """
        Evaluate predictions on the given tasks.
        Fill self._results with the metrics of the tasks.

        Args:
            predictions (list[dict]): list of outputs from the model
        """
        self._logger.info("Preparing results in the LVIS format ...")
        lvis_results = list(itertools.chain(*[x["instances"] for x in predictions]))

        # LVIS evaluator can be used to evaluate results for COCO dataset categories.
        # In this case `_metadata` variable will have a field with COCO-specific category mapping.
        if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
            reverse_id_mapping = {
                v: k for k, v in self._metadata.thing_dataset_id_to_contiguous_id.items()
            }
            for result in lvis_results:
                result["category_id"] = reverse_id_mapping[result["category_id"]]
        else:
            # unmap the category ids for LVIS (from 0-indexed to 1-indexed)
            for result in lvis_results:
                result["category_id"] += 1

        if self._output_dir:
            file_path = os.path.join(self._output_dir, "lvis_instances_results.json")
            self._logger.info("Saving results to {}".format(file_path))
            with PathManager.open(file_path, "w") as f:
                f.write(json.dumps(lvis_results))
                f.flush()

        if not self._do_evaluation:
            self._logger.info("Annotations are not available for evaluation.")
            return

        self._logger.info("Evaluating predictions ...")
        for task in sorted(tasks):
            res = _evaluate_predictions_on_lvis(
                self._lvis_api, lvis_results, task, class_names=self._metadata.get("thing_classes")
            )
            self._results[task] = res

    def _eval_box_proposals(self, predictions):
        """
        Evaluate the box proposals in predictions.
        Fill self._results with the metrics for "box_proposals" task.
        """
        if self._output_dir:
            # Saving generated box proposals to file.
            # Predicted box_proposals are in XYXY_ABS mode.
            bbox_mode = BoxMode.XYXY_ABS.value
            ids, boxes, objectness_logits = [], [], []
            for prediction in predictions:
                ids.append(prediction["image_id"])
                boxes.append(prediction["proposals"].proposal_boxes.tensor.numpy())
                objectness_logits.append(prediction["proposals"].objectness_logits.numpy())

            proposal_data = {
                "boxes": boxes,
                "objectness_logits": objectness_logits,
                "ids": ids,
                "bbox_mode": bbox_mode,
            }
            with PathManager.open(os.path.join(self._output_dir, "box_proposals.pkl"), "wb") as f:
                pickle.dump(proposal_data, f)

        if not self._do_evaluation:
            self._logger.info("Annotations are not available for evaluation.")
            return

        self._logger.info("Evaluating bbox proposals ...")
        res = {}
        areas = {"all": "", "small": "s", "medium": "m", "large": "l"}
        for limit in [100, 1000]:
            for area, suffix in areas.items():
                stats = _evaluate_box_proposals(predictions, self._lvis_api, area=area, limit=limit)
                key = "AR{}@{:d}".format(suffix, limit)
                res[key] = float(stats["ar"].item() * 100)
        self._logger.info("Proposal metrics: \n" + create_small_table(res))
        self._results["box_proposals"] = res


# inspired from Detectron:
# https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L255 # noqa
def _evaluate_box_proposals(dataset_predictions, lvis_api, thresholds=None, area="all", limit=None):
    """
    Evaluate detection proposal recall metrics. This function is a much
    faster alternative to the official LVIS API recall evaluation code. However,
    it produces slightly different results.
    """
    # Record max overlap value for each gt box
    # Return vector of overlap values
    areas = {
        "all": 0,
        "small": 1,
        "medium": 2,
        "large": 3,
        "96-128": 4,
        "128-256": 5,
        "256-512": 6,
        "512-inf": 7,
    }
    area_ranges = [
        [0 ** 2, 1e5 ** 2],  # all
        [0 ** 2, 32 ** 2],  # small
        [32 ** 2, 96 ** 2],  # medium
        [96 ** 2, 1e5 ** 2],  # large
        [96 ** 2, 128 ** 2],  # 96-128
        [128 ** 2, 256 ** 2],  # 128-256
        [256 ** 2, 512 ** 2],  # 256-512
        [512 ** 2, 1e5 ** 2],
    ]  # 512-inf
    assert area in areas, "Unknown area range: {}".format(area)
    area_range = area_ranges[areas[area]]
    gt_overlaps = []
    num_pos = 0

    for prediction_dict in dataset_predictions:
        predictions = prediction_dict["proposals"]

        # sort predictions in descending order
        # TODO maybe remove this and make it explicit in the documentation
        inds = predictions.objectness_logits.sort(descending=True)[1]
        predictions = predictions[inds]

        ann_ids = lvis_api.get_ann_ids(img_ids=[prediction_dict["image_id"]])
        anno = lvis_api.load_anns(ann_ids)
        gt_boxes = [
            BoxMode.convert(obj["bbox"], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS) for obj in anno
        ]
        gt_boxes = torch.as_tensor(gt_boxes).reshape(-1, 4)  # guard against no boxes
        gt_boxes = Boxes(gt_boxes)
        gt_areas = torch.as_tensor([obj["area"] for obj in anno])

        if len(gt_boxes) == 0 or len(predictions) == 0:
            continue

        valid_gt_inds = (gt_areas >= area_range[0]) & (gt_areas <= area_range[1])
        gt_boxes = gt_boxes[valid_gt_inds]

        num_pos += len(gt_boxes)

        if len(gt_boxes) == 0:
            continue

        if limit is not None and len(predictions) > limit:
            predictions = predictions[:limit]

        overlaps = pairwise_iou(predictions.proposal_boxes, gt_boxes)

        _gt_overlaps = torch.zeros(len(gt_boxes))
        for j in range(min(len(predictions), len(gt_boxes))):
            # find which proposal box maximally covers each gt box
            # and get the iou amount of coverage for each gt box
            max_overlaps, argmax_overlaps = overlaps.max(dim=0)

            # find which gt box is 'best' covered (i.e. 'best' = most iou)
            gt_ovr, gt_ind = max_overlaps.max(dim=0)
            assert gt_ovr >= 0
            # find the proposal box that covers the best covered gt box
            box_ind = argmax_overlaps[gt_ind]
            # record the iou coverage of this gt box
            _gt_overlaps[j] = overlaps[box_ind, gt_ind]
            assert _gt_overlaps[j] == gt_ovr
            # mark the proposal box and the gt box as used
            overlaps[box_ind, :] = -1
            overlaps[:, gt_ind] = -1

        # append recorded iou coverage level
        gt_overlaps.append(_gt_overlaps)
    gt_overlaps = (
        torch.cat(gt_overlaps, dim=0) if len(gt_overlaps) else torch.zeros(0, dtype=torch.float32)
    )
    gt_overlaps, _ = torch.sort(gt_overlaps)

    if thresholds is None:
        step = 0.05
        thresholds = torch.arange(0.5, 0.95 + 1e-5, step, dtype=torch.float32)
    recalls = torch.zeros_like(thresholds)
    # compute recall for each iou threshold
    for i, t in enumerate(thresholds):
        recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos)
    # ar = 2 * np.trapz(recalls, thresholds)
    ar = recalls.mean()
    return {
        "ar": ar,
        "recalls": recalls,
        "thresholds": thresholds,
        "gt_overlaps": gt_overlaps,
        "num_pos": num_pos,
    }


def _evaluate_predictions_on_lvis(lvis_gt, lvis_results, iou_type, class_names=None):
    """
    Args:
        iou_type (str):
        kpt_oks_sigmas (list[float]):
        class_names (None or list[str]): if provided, will use it to predict
            per-category AP.

    Returns:
        a dict of {metric name: score}
    """
    metrics = {
        "bbox": ["AP", "AP50", "AP75", "APs", "APm", "APl", "APr", "APc", "APf"],
        "segm": ["AP", "AP50", "AP75", "APs", "APm", "APl", "APr", "APc", "APf"],
    }[iou_type]

    logger = logging.getLogger(__name__)

    if len(lvis_results) == 0:  # TODO: check if needed
        logger.warn("No predictions from the model!")
        return {metric: float("nan") for metric in metrics}

    if iou_type == "segm":
        lvis_results = copy.deepcopy(lvis_results)
        # When evaluating mask AP, if the results contain bbox, LVIS API will
        # use the box area as the area of the instance, instead of the mask area.
        # This leads to a different definition of small/medium/large.
        # We remove the bbox field to let mask AP use mask area.
        for c in lvis_results:
            c.pop("bbox", None)

    from lvis import LVISEval, LVISResults

    lvis_results = LVISResults(lvis_gt, lvis_results)
    lvis_eval = LVISEval(lvis_gt, lvis_results, iou_type)
    lvis_eval.run()
    lvis_eval.print_results()

    # Pull the standard metrics from the LVIS results
    results = lvis_eval.get_results()
    results = {metric: float(results[metric] * 100) for metric in metrics}
    logger.info("Evaluation results for {}: \n".format(iou_type) + create_small_table(results))
    return results
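For orientation, `_evaluate_box_proposals` reduces proposal quality to an average recall over IoU thresholds, which `_eval_box_proposals` then reports under keys such as "AR@100" or "ARs@1000". The toy sketch below mirrors that final step; the IoU values are illustrative only and not taken from this diff.

import torch

# Toy per-GT best-IoU values, standing in for the sorted "gt_overlaps" tensor above.
gt_overlaps = torch.tensor([0.3, 0.55, 0.8, 0.95])
num_pos = 4
thresholds = torch.arange(0.5, 0.95 + 1e-5, 0.05, dtype=torch.float32)
recalls = torch.stack([(gt_overlaps >= t).float().sum() / num_pos for t in thresholds])
ar = recalls.mean()  # _eval_box_proposals reports 100 * ar under keys like "AR@100"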
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/evaluation/panoptic_evaluation.py
ADDED
@@ -0,0 +1,167 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import contextlib
import io
import itertools
import json
import logging
import os
import tempfile
from collections import OrderedDict
from fvcore.common.file_io import PathManager
from PIL import Image
from tabulate import tabulate

from detectron2.data import MetadataCatalog
from detectron2.utils import comm

from .evaluator import DatasetEvaluator

logger = logging.getLogger(__name__)


class COCOPanopticEvaluator(DatasetEvaluator):
    """
    Evaluate Panoptic Quality metrics on COCO using PanopticAPI.
    It saves panoptic segmentation prediction in `output_dir`

    It contains a synchronize call and has to be called from all workers.
    """

    def __init__(self, dataset_name, output_dir):
        """
        Args:
            dataset_name (str): name of the dataset
            output_dir (str): output directory to save results for evaluation
        """
        self._metadata = MetadataCatalog.get(dataset_name)
        self._thing_contiguous_id_to_dataset_id = {
            v: k for k, v in self._metadata.thing_dataset_id_to_contiguous_id.items()
        }
        self._stuff_contiguous_id_to_dataset_id = {
            v: k for k, v in self._metadata.stuff_dataset_id_to_contiguous_id.items()
        }

        self._predictions_json = os.path.join(output_dir, "predictions.json")

    def reset(self):
        self._predictions = []

    def _convert_category_id(self, segment_info):
        isthing = segment_info.pop("isthing", None)
        if isthing is None:
            # the model produces panoptic category id directly. No more conversion needed
            return segment_info
        if isthing is True:
            segment_info["category_id"] = self._thing_contiguous_id_to_dataset_id[
                segment_info["category_id"]
            ]
        else:
            segment_info["category_id"] = self._stuff_contiguous_id_to_dataset_id[
                segment_info["category_id"]
            ]
        return segment_info

    def process(self, inputs, outputs):
        from panopticapi.utils import id2rgb

        for input, output in zip(inputs, outputs):
            panoptic_img, segments_info = output["panoptic_seg"]
            panoptic_img = panoptic_img.cpu().numpy()

            file_name = os.path.basename(input["file_name"])
            file_name_png = os.path.splitext(file_name)[0] + ".png"
            with io.BytesIO() as out:
                Image.fromarray(id2rgb(panoptic_img)).save(out, format="PNG")
                segments_info = [self._convert_category_id(x) for x in segments_info]
                self._predictions.append(
                    {
                        "image_id": input["image_id"],
                        "file_name": file_name_png,
                        "png_string": out.getvalue(),
                        "segments_info": segments_info,
                    }
                )

    def evaluate(self):
        comm.synchronize()

        self._predictions = comm.gather(self._predictions)
        self._predictions = list(itertools.chain(*self._predictions))
        if not comm.is_main_process():
            return

        # PanopticApi requires local files
        gt_json = PathManager.get_local_path(self._metadata.panoptic_json)
        gt_folder = PathManager.get_local_path(self._metadata.panoptic_root)

        with tempfile.TemporaryDirectory(prefix="panoptic_eval") as pred_dir:
            logger.info("Writing all panoptic predictions to {} ...".format(pred_dir))
            for p in self._predictions:
                with open(os.path.join(pred_dir, p["file_name"]), "wb") as f:
                    f.write(p.pop("png_string"))

            with open(gt_json, "r") as f:
                json_data = json.load(f)
            json_data["annotations"] = self._predictions
            with PathManager.open(self._predictions_json, "w") as f:
                f.write(json.dumps(json_data))

            from panopticapi.evaluation import pq_compute

            with contextlib.redirect_stdout(io.StringIO()):
                pq_res = pq_compute(
                    gt_json,
                    PathManager.get_local_path(self._predictions_json),
                    gt_folder=gt_folder,
                    pred_folder=pred_dir,
                )

        res = {}
        res["PQ"] = 100 * pq_res["All"]["pq"]
        res["SQ"] = 100 * pq_res["All"]["sq"]
        res["RQ"] = 100 * pq_res["All"]["rq"]
        res["PQ_th"] = 100 * pq_res["Things"]["pq"]
        res["SQ_th"] = 100 * pq_res["Things"]["sq"]
        res["RQ_th"] = 100 * pq_res["Things"]["rq"]
        res["PQ_st"] = 100 * pq_res["Stuff"]["pq"]
        res["SQ_st"] = 100 * pq_res["Stuff"]["sq"]
        res["RQ_st"] = 100 * pq_res["Stuff"]["rq"]

        results = OrderedDict({"panoptic_seg": res})
        _print_panoptic_results(pq_res)

        return results


def _print_panoptic_results(pq_res):
    headers = ["", "PQ", "SQ", "RQ", "#categories"]
    data = []
    for name in ["All", "Things", "Stuff"]:
        row = [name] + [pq_res[name][k] * 100 for k in ["pq", "sq", "rq"]] + [pq_res[name]["n"]]
        data.append(row)
    table = tabulate(
        data, headers=headers, tablefmt="pipe", floatfmt=".3f", stralign="center", numalign="center"
    )
    logger.info("Panoptic Evaluation Results:\n" + table)


if __name__ == "__main__":
    from detectron2.utils.logger import setup_logger

    logger = setup_logger()
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--gt-json")
    parser.add_argument("--gt-dir")
    parser.add_argument("--pred-json")
    parser.add_argument("--pred-dir")
    args = parser.parse_args()

    from panopticapi.evaluation import pq_compute

    with contextlib.redirect_stdout(io.StringIO()):
        pq_res = pq_compute(
            args.gt_json, args.pred_json, gt_folder=args.gt_dir, pred_folder=args.pred_dir
        )
    _print_panoptic_results(pq_res)
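As the class docstring notes, the evaluator synchronizes across workers and must run on all ranks. A minimal usage sketch is shown below; the dataset name, config, model, and output directory are assumptions for illustration and not part of this diff.

from detectron2.data import build_detection_test_loader
from detectron2.evaluation import COCOPanopticEvaluator, inference_on_dataset

# Hypothetical wiring; assumes a registered panoptic dataset and a model whose
# outputs contain "panoptic_seg".
# evaluator = COCOPanopticEvaluator("coco_2017_val_panoptic", output_dir="./output")
# val_loader = build_detection_test_loader(cfg, "coco_2017_val_panoptic")
# results = inference_on_dataset(model, val_loader, evaluator)  # {"panoptic_seg": {"PQ": ..., ...}}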
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/evaluation/pascal_voc_evaluation.py
ADDED
@@ -0,0 +1,294 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import logging
import numpy as np
import os
import tempfile
import xml.etree.ElementTree as ET
from collections import OrderedDict, defaultdict
from functools import lru_cache
import torch
from fvcore.common.file_io import PathManager

from detectron2.data import MetadataCatalog
from detectron2.utils import comm

from .evaluator import DatasetEvaluator


class PascalVOCDetectionEvaluator(DatasetEvaluator):
    """
    Evaluate Pascal VOC AP.
    It contains a synchronization, therefore has to be called from all ranks.

    Note that this is a rewrite of the official Matlab API.
    The results should be similar, but not identical to the one produced by
    the official API.
    """

    def __init__(self, dataset_name):
        """
        Args:
            dataset_name (str): name of the dataset, e.g., "voc_2007_test"
        """
        self._dataset_name = dataset_name
        meta = MetadataCatalog.get(dataset_name)
        self._anno_file_template = os.path.join(meta.dirname, "Annotations", "{}.xml")
        self._image_set_path = os.path.join(meta.dirname, "ImageSets", "Main", meta.split + ".txt")
        self._class_names = meta.thing_classes
        assert meta.year in [2007, 2012], meta.year
        self._is_2007 = meta.year == 2007
        self._cpu_device = torch.device("cpu")
        self._logger = logging.getLogger(__name__)

    def reset(self):
        self._predictions = defaultdict(list)  # class name -> list of prediction strings

    def process(self, inputs, outputs):
        for input, output in zip(inputs, outputs):
            image_id = input["image_id"]
            instances = output["instances"].to(self._cpu_device)
            boxes = instances.pred_boxes.tensor.numpy()
            scores = instances.scores.tolist()
            classes = instances.pred_classes.tolist()
            for box, score, cls in zip(boxes, scores, classes):
                xmin, ymin, xmax, ymax = box
                # The inverse of data loading logic in `data/pascal_voc.py`
                xmin += 1
                ymin += 1
                self._predictions[cls].append(
                    f"{image_id} {score:.3f} {xmin:.1f} {ymin:.1f} {xmax:.1f} {ymax:.1f}"
                )

    def evaluate(self):
        """
        Returns:
            dict: has a key "segm", whose value is a dict of "AP", "AP50", and "AP75".
        """
        all_predictions = comm.gather(self._predictions, dst=0)
        if not comm.is_main_process():
            return
        predictions = defaultdict(list)
        for predictions_per_rank in all_predictions:
            for clsid, lines in predictions_per_rank.items():
                predictions[clsid].extend(lines)
        del all_predictions

        self._logger.info(
            "Evaluating {} using {} metric. "
            "Note that results do not use the official Matlab API.".format(
                self._dataset_name, 2007 if self._is_2007 else 2012
            )
        )

        with tempfile.TemporaryDirectory(prefix="pascal_voc_eval_") as dirname:
            res_file_template = os.path.join(dirname, "{}.txt")

            aps = defaultdict(list)  # iou -> ap per class
            for cls_id, cls_name in enumerate(self._class_names):
                lines = predictions.get(cls_id, [""])

                with open(res_file_template.format(cls_name), "w") as f:
                    f.write("\n".join(lines))

                for thresh in range(50, 100, 5):
                    rec, prec, ap = voc_eval(
                        res_file_template,
                        self._anno_file_template,
                        self._image_set_path,
                        cls_name,
                        ovthresh=thresh / 100.0,
                        use_07_metric=self._is_2007,
                    )
                    aps[thresh].append(ap * 100)

        ret = OrderedDict()
        mAP = {iou: np.mean(x) for iou, x in aps.items()}
        ret["bbox"] = {"AP": np.mean(list(mAP.values())), "AP50": mAP[50], "AP75": mAP[75]}
        return ret


##############################################################################
#
# Below code is modified from
# https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/voc_eval.py
# --------------------------------------------------------
# Fast/er R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Bharath Hariharan
# --------------------------------------------------------

"""Python implementation of the PASCAL VOC devkit's AP evaluation code."""


@lru_cache(maxsize=None)
def parse_rec(filename):
    """Parse a PASCAL VOC xml file."""
    with PathManager.open(filename) as f:
        tree = ET.parse(f)
    objects = []
    for obj in tree.findall("object"):
        obj_struct = {}
        obj_struct["name"] = obj.find("name").text
        obj_struct["pose"] = obj.find("pose").text
        obj_struct["truncated"] = int(obj.find("truncated").text)
        obj_struct["difficult"] = int(obj.find("difficult").text)
        bbox = obj.find("bndbox")
        obj_struct["bbox"] = [
            int(bbox.find("xmin").text),
            int(bbox.find("ymin").text),
            int(bbox.find("xmax").text),
            int(bbox.find("ymax").text),
        ]
        objects.append(obj_struct)

    return objects


def voc_ap(rec, prec, use_07_metric=False):
    """Compute VOC AP given precision and recall. If use_07_metric is true, uses
    the VOC 07 11-point method (default:False).
    """
    if use_07_metric:
        # 11 point metric
        ap = 0.0
        for t in np.arange(0.0, 1.1, 0.1):
            if np.sum(rec >= t) == 0:
                p = 0
            else:
                p = np.max(prec[rec >= t])
            ap = ap + p / 11.0
    else:
        # correct AP calculation
        # first append sentinel values at the end
        mrec = np.concatenate(([0.0], rec, [1.0]))
        mpre = np.concatenate(([0.0], prec, [0.0]))

        # compute the precision envelope
        for i in range(mpre.size - 1, 0, -1):
            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

        # to calculate area under PR curve, look for points
        # where X axis (recall) changes value
        i = np.where(mrec[1:] != mrec[:-1])[0]

        # and sum (\Delta recall) * prec
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap


def voc_eval(detpath, annopath, imagesetfile, classname, ovthresh=0.5, use_07_metric=False):
    """rec, prec, ap = voc_eval(detpath,
                                annopath,
                                imagesetfile,
                                classname,
                                [ovthresh],
                                [use_07_metric])

    Top level function that does the PASCAL VOC evaluation.

    detpath: Path to detections
        detpath.format(classname) should produce the detection results file.
    annopath: Path to annotations
        annopath.format(imagename) should be the xml annotations file.
    imagesetfile: Text file containing the list of images, one image per line.
    classname: Category name (duh)
    [ovthresh]: Overlap threshold (default = 0.5)
    [use_07_metric]: Whether to use VOC07's 11 point AP computation
        (default False)
    """
    # assumes detections are in detpath.format(classname)
    # assumes annotations are in annopath.format(imagename)
    # assumes imagesetfile is a text file with each line an image name

    # first load gt
    # read list of images
    with PathManager.open(imagesetfile, "r") as f:
        lines = f.readlines()
    imagenames = [x.strip() for x in lines]

    # load annots
    recs = {}
    for imagename in imagenames:
        recs[imagename] = parse_rec(annopath.format(imagename))

    # extract gt objects for this class
    class_recs = {}
    npos = 0
    for imagename in imagenames:
        R = [obj for obj in recs[imagename] if obj["name"] == classname]
        bbox = np.array([x["bbox"] for x in R])
        difficult = np.array([x["difficult"] for x in R]).astype(np.bool)
        # difficult = np.array([False for x in R]).astype(np.bool)  # treat all "difficult" as GT
        det = [False] * len(R)
        npos = npos + sum(~difficult)
        class_recs[imagename] = {"bbox": bbox, "difficult": difficult, "det": det}

    # read dets
    detfile = detpath.format(classname)
    with open(detfile, "r") as f:
        lines = f.readlines()

    splitlines = [x.strip().split(" ") for x in lines]
    image_ids = [x[0] for x in splitlines]
    confidence = np.array([float(x[1]) for x in splitlines])
    BB = np.array([[float(z) for z in x[2:]] for x in splitlines]).reshape(-1, 4)

    # sort by confidence
    sorted_ind = np.argsort(-confidence)
    BB = BB[sorted_ind, :]
    image_ids = [image_ids[x] for x in sorted_ind]

    # go down dets and mark TPs and FPs
    nd = len(image_ids)
    tp = np.zeros(nd)
    fp = np.zeros(nd)
    for d in range(nd):
        R = class_recs[image_ids[d]]
        bb = BB[d, :].astype(float)
        ovmax = -np.inf
        BBGT = R["bbox"].astype(float)

        if BBGT.size > 0:
            # compute overlaps
            # intersection
            ixmin = np.maximum(BBGT[:, 0], bb[0])
            iymin = np.maximum(BBGT[:, 1], bb[1])
            ixmax = np.minimum(BBGT[:, 2], bb[2])
            iymax = np.minimum(BBGT[:, 3], bb[3])
            iw = np.maximum(ixmax - ixmin + 1.0, 0.0)
            ih = np.maximum(iymax - iymin + 1.0, 0.0)
            inters = iw * ih

            # union
            uni = (
                (bb[2] - bb[0] + 1.0) * (bb[3] - bb[1] + 1.0)
                + (BBGT[:, 2] - BBGT[:, 0] + 1.0) * (BBGT[:, 3] - BBGT[:, 1] + 1.0)
                - inters
            )

            overlaps = inters / uni
            ovmax = np.max(overlaps)
            jmax = np.argmax(overlaps)

        if ovmax > ovthresh:
            if not R["difficult"][jmax]:
                if not R["det"][jmax]:
                    tp[d] = 1.0
                    R["det"][jmax] = 1
                else:
                    fp[d] = 1.0
        else:
            fp[d] = 1.0

    # compute precision recall
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    rec = tp / float(npos)
    # avoid divide by zero in case the first detection matches a difficult
    # ground truth
    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
    ap = voc_ap(rec, prec, use_07_metric)

    return rec, prec, ap
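Since `voc_ap` is self-contained, a tiny worked example makes the two AP variants concrete. The precision/recall points below are illustrative only; the module path follows the file location in this diff.

import numpy as np
from detectron2.evaluation.pascal_voc_evaluation import voc_ap

# Illustrative PR points: precision 1.0 up to recall 0.5, then 0.5 up to recall 1.0.
rec = np.array([0.5, 1.0])
prec = np.array([1.0, 0.5])
print(voc_ap(rec, prec))                      # area under the interpolated curve: 0.5*1.0 + 0.5*0.5 = 0.75
print(voc_ap(rec, prec, use_07_metric=True))  # VOC07 11-point variant: (6*1.0 + 5*0.5) / 11 ≈ 0.773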
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/evaluation/rotated_coco_evaluation.py
ADDED
@@ -0,0 +1,204 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import itertools
import json
import numpy as np
import os
import torch
from fvcore.common.file_io import PathManager
from pycocotools.cocoeval import COCOeval, maskUtils

from detectron2.structures import BoxMode, RotatedBoxes, pairwise_iou_rotated

from .coco_evaluation import COCOEvaluator


class RotatedCOCOeval(COCOeval):
    @staticmethod
    def is_rotated(box_list):
        if type(box_list) == np.ndarray:
            return box_list.shape[1] == 5
        elif type(box_list) == list:
            if box_list == []:  # cannot decide the box_dim
                return False
            return np.all(
                np.array(
                    [
                        (len(obj) == 5) and ((type(obj) == list) or (type(obj) == np.ndarray))
                        for obj in box_list
                    ]
                )
            )
        return False

    @staticmethod
    def boxlist_to_tensor(boxlist, output_box_dim):
        if type(boxlist) == np.ndarray:
            box_tensor = torch.from_numpy(boxlist)
        elif type(boxlist) == list:
            if boxlist == []:
                return torch.zeros((0, output_box_dim), dtype=torch.float32)
            else:
                box_tensor = torch.FloatTensor(boxlist)
        else:
            raise Exception("Unrecognized boxlist type")

        input_box_dim = box_tensor.shape[1]
        if input_box_dim != output_box_dim:
            if input_box_dim == 4 and output_box_dim == 5:
                box_tensor = BoxMode.convert(box_tensor, BoxMode.XYWH_ABS, BoxMode.XYWHA_ABS)
            else:
                raise Exception(
                    "Unable to convert from {}-dim box to {}-dim box".format(
                        input_box_dim, output_box_dim
                    )
                )
        return box_tensor

    def compute_iou_dt_gt(self, dt, gt, is_crowd):
        if self.is_rotated(dt) or self.is_rotated(gt):
            # TODO: take is_crowd into consideration
            assert all(c == 0 for c in is_crowd)
            dt = RotatedBoxes(self.boxlist_to_tensor(dt, output_box_dim=5))
            gt = RotatedBoxes(self.boxlist_to_tensor(gt, output_box_dim=5))
            return pairwise_iou_rotated(dt, gt)
        else:
            # This is the same as the classical COCO evaluation
            return maskUtils.iou(dt, gt, is_crowd)

    def computeIoU(self, imgId, catId):
        p = self.params
        if p.useCats:
            gt = self._gts[imgId, catId]
            dt = self._dts[imgId, catId]
        else:
            gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
            dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
        if len(gt) == 0 and len(dt) == 0:
            return []
        inds = np.argsort([-d["score"] for d in dt], kind="mergesort")
        dt = [dt[i] for i in inds]
        if len(dt) > p.maxDets[-1]:
            dt = dt[0 : p.maxDets[-1]]

        assert p.iouType == "bbox", "unsupported iouType for iou computation"

        g = [g["bbox"] for g in gt]
        d = [d["bbox"] for d in dt]

        # compute iou between each dt and gt region
        iscrowd = [int(o["iscrowd"]) for o in gt]

        # Note: this function is copied from cocoeval.py in cocoapi
        # and the major difference is here.
        ious = self.compute_iou_dt_gt(d, g, iscrowd)
        return ious


class RotatedCOCOEvaluator(COCOEvaluator):
    """
    Evaluate object proposal/instance detection outputs using COCO-like metrics and APIs,
    with rotated boxes support.
    Note: this uses IOU only and does not consider angle differences.
    """

    def process(self, inputs, outputs):
        """
        Args:
            inputs: the inputs to a COCO model (e.g., GeneralizedRCNN).
                It is a list of dict. Each dict corresponds to an image and
                contains keys like "height", "width", "file_name", "image_id".
            outputs: the outputs of a COCO model. It is a list of dicts with key
                "instances" that contains :class:`Instances`.
        """
        for input, output in zip(inputs, outputs):
            prediction = {"image_id": input["image_id"]}

            if "instances" in output:
                instances = output["instances"].to(self._cpu_device)

                prediction["instances"] = self.instances_to_json(instances, input["image_id"])
            if "proposals" in output:
                prediction["proposals"] = output["proposals"].to(self._cpu_device)
            self._predictions.append(prediction)

    def instances_to_json(self, instances, img_id):
        num_instance = len(instances)
        if num_instance == 0:
            return []

        boxes = instances.pred_boxes.tensor.numpy()
        if boxes.shape[1] == 4:
            boxes = BoxMode.convert(boxes, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
        boxes = boxes.tolist()
        scores = instances.scores.tolist()
        classes = instances.pred_classes.tolist()

        results = []
        for k in range(num_instance):
            result = {
                "image_id": img_id,
                "category_id": classes[k],
                "bbox": boxes[k],
                "score": scores[k],
            }

            results.append(result)
        return results

    def _eval_predictions(self, tasks, predictions):
        """
        Evaluate predictions on the given tasks.
        Fill self._results with the metrics of the tasks.
        """
        self._logger.info("Preparing results for COCO format ...")
        coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))

        # unmap the category ids for COCO
        if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
            reverse_id_mapping = {
                v: k for k, v in self._metadata.thing_dataset_id_to_contiguous_id.items()
            }
            for result in coco_results:
                result["category_id"] = reverse_id_mapping[result["category_id"]]

        if self._output_dir:
            file_path = os.path.join(self._output_dir, "coco_instances_results.json")
            self._logger.info("Saving results to {}".format(file_path))
            with PathManager.open(file_path, "w") as f:
                f.write(json.dumps(coco_results))
                f.flush()

        if not self._do_evaluation:
            self._logger.info("Annotations are not available for evaluation.")
            return

        self._logger.info("Evaluating predictions ...")
        for task in sorted(tasks):
            assert task == "bbox", "Task {} is not supported".format(task)
            coco_eval = (
                self._evaluate_predictions_on_coco(self._coco_api, coco_results)
                if len(coco_results) > 0
                else None  # cocoapi does not handle empty results very well
            )

            res = self._derive_coco_results(
                coco_eval, task, class_names=self._metadata.get("thing_classes")
            )
            self._results[task] = res

    def _evaluate_predictions_on_coco(self, coco_gt, coco_results):
        """
        Evaluate the coco results using COCOEval API.
        """
        assert len(coco_results) > 0

        coco_dt = coco_gt.loadRes(coco_results)

        # Only bbox is supported for now
        coco_eval = RotatedCOCOeval(coco_gt, coco_dt, iouType="bbox")

        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()

        return coco_eval
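As a usage sketch only (the dataset name, config, model, and loader are assumptions, not from this diff), the rotated evaluator is intended as a drop-in for the standard COCO evaluator when a model predicts 5-dim (cx, cy, w, h, angle) boxes, and only the "bbox" task is supported.

from detectron2.evaluation import inference_on_dataset
from detectron2.evaluation.rotated_coco_evaluation import RotatedCOCOEvaluator

# Hypothetical wiring; constructor arguments follow the parent COCOEvaluator of this version.
# evaluator = RotatedCOCOEvaluator("my_rotated_dataset", cfg, True, output_dir="./output")
# results = inference_on_dataset(model, data_loader, evaluator)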
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/evaluation/sem_seg_evaluation.py
ADDED
@@ -0,0 +1,168 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import itertools
import json
import logging
import numpy as np
import os
from collections import OrderedDict
import PIL.Image as Image
import pycocotools.mask as mask_util
import torch
from fvcore.common.file_io import PathManager

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.utils.comm import all_gather, is_main_process, synchronize

from .evaluator import DatasetEvaluator


class SemSegEvaluator(DatasetEvaluator):
    """
    Evaluate semantic segmentation
    """

    def __init__(self, dataset_name, distributed, num_classes, ignore_label=255, output_dir=None):
        """
        Args:
            dataset_name (str): name of the dataset to be evaluated.
            distributed (True): if True, will collect results from all ranks for evaluation.
                Otherwise, will evaluate the results in the current process.
            num_classes (int): number of classes
            ignore_label (int): value in semantic segmentation ground truth. Predictions for the
                corresponding pixels should be ignored.
            output_dir (str): an output directory to dump results.
        """
        self._dataset_name = dataset_name
        self._distributed = distributed
        self._output_dir = output_dir
        self._num_classes = num_classes
        self._ignore_label = ignore_label
        self._N = num_classes + 1

        self._cpu_device = torch.device("cpu")
        self._logger = logging.getLogger(__name__)

        self.input_file_to_gt_file = {
            dataset_record["file_name"]: dataset_record["sem_seg_file_name"]
            for dataset_record in DatasetCatalog.get(dataset_name)
        }

        meta = MetadataCatalog.get(dataset_name)
        # Dict that maps contiguous training ids to COCO category ids
        try:
            c2d = meta.stuff_dataset_id_to_contiguous_id
            self._contiguous_id_to_dataset_id = {v: k for k, v in c2d.items()}
        except AttributeError:
            self._contiguous_id_to_dataset_id = None
        self._class_names = meta.stuff_classes

    def reset(self):
        self._conf_matrix = np.zeros((self._N, self._N), dtype=np.int64)
        self._predictions = []

    def process(self, inputs, outputs):
        """
        Args:
            inputs: the inputs to a model.
                It is a list of dicts. Each dict corresponds to an image and
                contains keys like "height", "width", "file_name".
            outputs: the outputs of a model. It is either list of semantic segmentation predictions
                (Tensor [H, W]) or list of dicts with key "sem_seg" that contains semantic
                segmentation prediction in the same format.
        """
        for input, output in zip(inputs, outputs):
            output = output["sem_seg"].argmax(dim=0).to(self._cpu_device)
            pred = np.array(output, dtype=np.int)
            with PathManager.open(self.input_file_to_gt_file[input["file_name"]], "rb") as f:
                gt = np.array(Image.open(f), dtype=np.int)

            gt[gt == self._ignore_label] = self._num_classes

            self._conf_matrix += np.bincount(
                self._N * pred.reshape(-1) + gt.reshape(-1), minlength=self._N ** 2
            ).reshape(self._N, self._N)

            self._predictions.extend(self.encode_json_sem_seg(pred, input["file_name"]))

    def evaluate(self):
        """
        Evaluates standard semantic segmentation metrics (http://cocodataset.org/#stuff-eval):

        * Mean intersection-over-union averaged across classes (mIoU)
        * Frequency Weighted IoU (fwIoU)
        * Mean pixel accuracy averaged across classes (mACC)
        * Pixel Accuracy (pACC)
        """
        if self._distributed:
            synchronize()
            conf_matrix_list = all_gather(self._conf_matrix)
            self._predictions = all_gather(self._predictions)
            self._predictions = list(itertools.chain(*self._predictions))
            if not is_main_process():
                return

            self._conf_matrix = np.zeros_like(self._conf_matrix)
            for conf_matrix in conf_matrix_list:
                self._conf_matrix += conf_matrix

        if self._output_dir:
            PathManager.mkdirs(self._output_dir)
            file_path = os.path.join(self._output_dir, "sem_seg_predictions.json")
            with PathManager.open(file_path, "w") as f:
                f.write(json.dumps(self._predictions))

        acc = np.full(self._num_classes, np.nan, dtype=np.float)
        iou = np.full(self._num_classes, np.nan, dtype=np.float)
        tp = self._conf_matrix.diagonal()[:-1].astype(np.float)
        pos_gt = np.sum(self._conf_matrix[:-1, :-1], axis=0).astype(np.float)
        class_weights = pos_gt / np.sum(pos_gt)
        pos_pred = np.sum(self._conf_matrix[:-1, :-1], axis=1).astype(np.float)
        acc_valid = pos_gt > 0
        acc[acc_valid] = tp[acc_valid] / pos_gt[acc_valid]
        iou_valid = (pos_gt + pos_pred) > 0
        union = pos_gt + pos_pred - tp
        iou[acc_valid] = tp[acc_valid] / union[acc_valid]
        macc = np.sum(acc[acc_valid]) / np.sum(acc_valid)
        miou = np.sum(iou[acc_valid]) / np.sum(iou_valid)
        fiou = np.sum(iou[acc_valid] * class_weights[acc_valid])
        pacc = np.sum(tp) / np.sum(pos_gt)

        res = {}
        res["mIoU"] = 100 * miou
        res["fwIoU"] = 100 * fiou
        for i, name in enumerate(self._class_names):
            res["IoU-{}".format(name)] = 100 * iou[i]
        res["mACC"] = 100 * macc
        res["pACC"] = 100 * pacc
        for i, name in enumerate(self._class_names):
            res["ACC-{}".format(name)] = 100 * acc[i]

        if self._output_dir:
            file_path = os.path.join(self._output_dir, "sem_seg_evaluation.pth")
            with PathManager.open(file_path, "wb") as f:
                torch.save(res, f)
        results = OrderedDict({"sem_seg": res})
        self._logger.info(results)
        return results

    def encode_json_sem_seg(self, sem_seg, input_file_name):
        """
        Convert semantic segmentation to COCO stuff format with segments encoded as RLEs.
        See http://cocodataset.org/#format-results
        """
        json_list = []
        for label in np.unique(sem_seg):
            if self._contiguous_id_to_dataset_id is not None:
                assert (
                    label in self._contiguous_id_to_dataset_id
                ), "Label {} is not in the metadata info for {}".format(label, self._dataset_name)
                dataset_id = self._contiguous_id_to_dataset_id[label]
            else:
                dataset_id = int(label)
            mask = (sem_seg == label).astype(np.uint8)
            mask_rle = mask_util.encode(np.array(mask[:, :, None], order="F"))[0]
            mask_rle["counts"] = mask_rle["counts"].decode("utf-8")
            json_list.append(
                {"file_name": input_file_name, "category_id": dataset_id, "segmentation": mask_rle}
            )
        return json_list
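The `np.bincount` trick in `process` above is what builds the confusion matrix: each pixel is mapped to the flat index `N * pred + gt` and the counts are reshaped to an N x N matrix with rows indexed by prediction and columns by ground truth. A tiny standalone sketch with made-up labels (2 classes plus the extra "ignore" slot, as in `self._N = num_classes + 1`):

import numpy as np

N = 3  # 2 classes + 1 slot for ignored pixels
pred = np.array([0, 0, 1, 1])
gt = np.array([0, 1, 1, 1])
conf = np.bincount(N * pred + gt, minlength=N ** 2).reshape(N, N)
# conf[i, j] counts pixels predicted as class i whose ground truth is class j:
# [[1, 1, 0],
#  [0, 2, 0],
#  [0, 0, 0]]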
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/evaluation/testing.py
ADDED
@@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import numpy as np
import pprint
import sys
from collections import OrderedDict
from collections.abc import Mapping


def print_csv_format(results):
    """
    Print main metrics in a format similar to Detectron,
    so that they are easy to copypaste into a spreadsheet.

    Args:
        results (OrderedDict[dict]): task_name -> {metric -> score}
    """
    assert isinstance(results, OrderedDict), results  # unordered results cannot be properly printed
    logger = logging.getLogger(__name__)
    for task, res in results.items():
        # Don't print "AP-category" metrics since they are usually not tracked.
        important_res = [(k, v) for k, v in res.items() if "-" not in k]
        logger.info("copypaste: Task: {}".format(task))
        logger.info("copypaste: " + ",".join([k[0] for k in important_res]))
        logger.info("copypaste: " + ",".join(["{0:.4f}".format(k[1]) for k in important_res]))


def verify_results(cfg, results):
    """
    Args:
        results (OrderedDict[dict]): task_name -> {metric -> score}

    Returns:
        bool: whether the verification succeeds or not
    """
    expected_results = cfg.TEST.EXPECTED_RESULTS
    if not len(expected_results):
        return True

    ok = True
    for task, metric, expected, tolerance in expected_results:
        actual = results[task][metric]
        if not np.isfinite(actual):
            ok = False
        diff = abs(actual - expected)
        if diff > tolerance:
            ok = False

    logger = logging.getLogger(__name__)
    if not ok:
        logger.error("Result verification failed!")
        logger.error("Expected Results: " + str(expected_results))
        logger.error("Actual Results: " + pprint.pformat(results))

        sys.exit(1)
    else:
        logger.info("Results verification passed.")
    return ok


def flatten_results_dict(results):
    """
    Expand a hierarchical dict of scalars into a flat dict of scalars.
    If results[k1][k2][k3] = v, the returned dict will have the entry
    {"k1/k2/k3": v}.

    Args:
        results (dict):
    """
    r = {}
    for k, v in results.items():
        if isinstance(v, Mapping):
            v = flatten_results_dict(v)
            for kk, vv in v.items():
                r[k + "/" + kk] = vv
        else:
            r[k] = v
    return r
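A short example of the flattening described in the docstring above; the metric values are made up for illustration and the module path follows the file location in this diff.

from detectron2.evaluation.testing import flatten_results_dict

nested = {"bbox": {"AP": 40.0, "AP50": 60.0}, "segm": {"AP": 35.0}}
print(flatten_results_dict(nested))
# {'bbox/AP': 40.0, 'bbox/AP50': 60.0, 'segm/AP': 35.0}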
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/export/README.md
ADDED
@@ -0,0 +1,10 @@

This directory contains code to prepare a detectron2 model for deployment.
Currently it supports exporting a detectron2 model to Caffe2 format through ONNX.

Please see [documentation](https://detectron2.readthedocs.io/tutorials/deployment.html) for its usage.


### Acknowledgements

Thanks to Mobile Vision team at Facebook for developing the conversion tools.
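A minimal sketch of the export flow implemented by `api.py` below; `cfg`, `model`, and `inputs` are placeholders that the caller must provide (they are not defined in this diff), and the exact steps are described in the linked deployment tutorial.

from detectron2.export import Caffe2Tracer, add_export_config

# cfg = add_export_config(cfg)           # add EXPORT_CAFFE2 options to an existing config
# tracer = Caffe2Tracer(cfg, model, inputs)
# caffe2_model = tracer.export_caffe2()   # Caffe2 protobufs, saveable via .save_protobuf()
# onnx_model = tracer.export_onnx()       # ONNX graph containing caffe2 custom ops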
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/export/__init__.py
ADDED
@@ -0,0 +1,5 @@
# -*- coding: utf-8 -*-

from .api import *

__all__ = [k for k in globals().keys() if not k.startswith("_")]
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/export/api.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
| 2 |
+
import copy
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import torch
|
| 6 |
+
from caffe2.proto import caffe2_pb2
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
from detectron2.config import CfgNode as CN
|
| 10 |
+
|
| 11 |
+
from .caffe2_export import export_caffe2_detection_model
|
| 12 |
+
from .caffe2_export import export_onnx_model as export_onnx_model_impl
|
| 13 |
+
from .caffe2_export import run_and_save_graph
|
| 14 |
+
from .caffe2_inference import ProtobufDetectionModel
|
| 15 |
+
from .caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP, convert_batched_inputs_to_c2_format
|
| 16 |
+
from .shared import get_pb_arg_vali, get_pb_arg_vals, save_graph
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
"add_export_config",
|
| 20 |
+
"export_caffe2_model",
|
| 21 |
+
"Caffe2Model",
|
| 22 |
+
"export_onnx_model",
|
| 23 |
+
"Caffe2Tracer",
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def add_export_config(cfg):
|
| 28 |
+
"""
|
| 29 |
+
Args:
|
| 30 |
+
cfg (CfgNode): a detectron2 config
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
CfgNode: an updated config with new options that will be used
|
| 34 |
+
by :class:`Caffe2Tracer`.
|
| 35 |
+
"""
|
| 36 |
+
is_frozen = cfg.is_frozen()
|
| 37 |
+
cfg.defrost()
|
| 38 |
+
cfg.EXPORT_CAFFE2 = CN()
|
| 39 |
+
cfg.EXPORT_CAFFE2.USE_HEATMAP_MAX_KEYPOINT = False
|
| 40 |
+
if is_frozen:
|
| 41 |
+
cfg.freeze()
|
| 42 |
+
return cfg
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class Caffe2Tracer:
|
| 46 |
+
"""
|
| 47 |
+
Make a detectron2 model traceable with caffe2 style.
|
| 48 |
+
|
| 49 |
+
An original detectron2 model may not be traceable, or
|
| 50 |
+
cannot be deployed directly after being traced, due to some reasons:
|
| 51 |
+
1. control flow in some ops
|
| 52 |
+
2. custom ops
|
| 53 |
+
3. complicated pre/post processing
|
| 54 |
+
|
| 55 |
+
This class provides a traceable version of a detectron2 model by:
|
| 56 |
+
1. Rewrite parts of the model using ops in caffe2. Note that some ops do
|
| 57 |
+
not have GPU implementation.
|
| 58 |
+
2. Define the inputs "after pre-processing" as inputs to the model
|
| 59 |
+
3. Remove post-processing and produce raw layer outputs
|
| 60 |
+
|
| 61 |
+
More specifically about inputs: all builtin models take two input tensors.
|
| 62 |
+
(1) NCHW float "data" which is an image (usually in [0, 255])
|
| 63 |
+
(2) Nx3 float "im_info", each row of which is (height, width, 1.0)
|
| 64 |
+
|
| 65 |
+
After making a traceable model, the class provide methods to export such a
|
| 66 |
+
model to different deployment formats.
|
| 67 |
+
|
| 68 |
+
The class currently only supports models using builtin meta architectures.
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
def __init__(self, cfg, model, inputs):
|
| 72 |
+
"""
|
| 73 |
+
Args:
|
| 74 |
+
cfg (CfgNode): a detectron2 config, with extra export-related options
|
| 75 |
+
added by :func:`add_export_config`.
|
| 76 |
+
model (nn.Module): a model built by
|
| 77 |
+
:func:`detectron2.modeling.build_model`.
|
| 78 |
+
inputs: sample inputs that the given model takes for inference.
|
| 79 |
+
Will be used to trace the model.
|
| 80 |
+
"""
|
| 81 |
+
assert isinstance(cfg, CN), cfg
|
| 82 |
+
assert isinstance(model, torch.nn.Module), type(model)
|
| 83 |
+
if "EXPORT_CAFFE2" not in cfg:
|
| 84 |
+
cfg = add_export_config(cfg)  # will just use the defaults
|
| 85 |
+
|
| 86 |
+
self.cfg = cfg
|
| 87 |
+
self.model = model
|
| 88 |
+
self.inputs = inputs
|
| 89 |
+
|
| 90 |
+
def _get_traceable(self):
|
| 91 |
+
# TODO how to make it extensible to support custom models
|
| 92 |
+
C2MetaArch = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[self.cfg.MODEL.META_ARCHITECTURE]
|
| 93 |
+
traceable_model = C2MetaArch(self.cfg, copy.deepcopy(self.model))
|
| 94 |
+
traceable_inputs = traceable_model.get_caffe2_inputs(self.inputs)
|
| 95 |
+
return traceable_model, traceable_inputs
|
| 96 |
+
|
| 97 |
+
def export_caffe2(self):
|
| 98 |
+
"""
|
| 99 |
+
Export the model to Caffe2's protobuf format.
|
| 100 |
+
The returned object can be saved with the `.save_protobuf()` method.
|
| 101 |
+
The result can be loaded and executed using Caffe2 runtime.
|
| 102 |
+
|
| 103 |
+
Returns:
|
| 104 |
+
Caffe2Model
|
| 105 |
+
"""
|
| 106 |
+
model, inputs = self._get_traceable()
|
| 107 |
+
predict_net, init_net = export_caffe2_detection_model(model, inputs)
|
| 108 |
+
return Caffe2Model(predict_net, init_net)
|
| 109 |
+
|
| 110 |
+
def export_onnx(self):
|
| 111 |
+
"""
|
| 112 |
+
Export the model to ONNX format.
|
| 113 |
+
Note that the exported model contains custom ops only available in caffe2, therefore it
|
| 114 |
+
cannot be directly executed by other runtimes. Post-processing or transformation passes
|
| 115 |
+
may be applied on the model to accommodate different runtimes.
|
| 116 |
+
|
| 117 |
+
Returns:
|
| 118 |
+
onnx.ModelProto: an onnx model.
|
| 119 |
+
"""
|
| 120 |
+
model, inputs = self._get_traceable()
|
| 121 |
+
return export_onnx_model_impl(model, (inputs,))
|
| 122 |
+
|
| 123 |
+
def export_torchscript(self):
|
| 124 |
+
"""
|
| 125 |
+
Export the model to a `torch.jit.TracedModule` by tracing.
|
| 126 |
+
The returned object can be saved to a file by ".save()".
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
torch.jit.TracedModule: a torch TracedModule
|
| 130 |
+
"""
|
| 131 |
+
model, inputs = self._get_traceable()
|
| 132 |
+
logger = logging.getLogger(__name__)
|
| 133 |
+
logger.info("Tracing the model with torch.jit.trace ...")
|
| 134 |
+
with torch.no_grad():
|
| 135 |
+
return torch.jit.trace(model, (inputs,), optimize=True)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def export_caffe2_model(cfg, model, inputs):
|
| 139 |
+
"""
|
| 140 |
+
Export a detectron2 model to caffe2 format.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
cfg (CfgNode): a detectron2 config, with extra export-related options
|
| 144 |
+
added by :func:`add_export_config`.
|
| 145 |
+
model (nn.Module): a model built by
|
| 146 |
+
:func:`detectron2.modeling.build_model`.
|
| 147 |
+
It will be modified by this function.
|
| 148 |
+
inputs: sample inputs that the given model takes for inference.
|
| 149 |
+
Will be used to trace the model.
|
| 150 |
+
|
| 151 |
+
Returns:
|
| 152 |
+
Caffe2Model
|
| 153 |
+
"""
|
| 154 |
+
return Caffe2Tracer(cfg, model, inputs).export_caffe2()
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def export_onnx_model(cfg, model, inputs):
|
| 158 |
+
"""
|
| 159 |
+
Export a detectron2 model to ONNX format.
|
| 160 |
+
Note that the exported model contains custom ops only available in caffe2, therefore it
|
| 161 |
+
cannot be directly executed by other runtimes. Post-processing or transformation passes
|
| 162 |
+
may be applied on the model to accommodate different runtimes.
|
| 163 |
+
Args:
|
| 164 |
+
cfg (CfgNode): a detectron2 config, with extra export-related options
|
| 165 |
+
added by :func:`add_export_config`.
|
| 166 |
+
model (nn.Module): a model built by
|
| 167 |
+
:func:`detectron2.modeling.build_model`.
|
| 168 |
+
It will be modified by this function.
|
| 169 |
+
inputs: sample inputs that the given model takes for inference.
|
| 170 |
+
Will be used to trace the model.
|
| 171 |
+
Returns:
|
| 172 |
+
onnx.ModelProto: an onnx model.
|
| 173 |
+
"""
|
| 174 |
+
return Caffe2Tracer(cfg, model, inputs).export_onnx()
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class Caffe2Model(nn.Module):
|
| 178 |
+
"""
|
| 179 |
+
A wrapper around the traced model in caffe2's pb format.
|
| 180 |
+
"""
|
| 181 |
+
|
| 182 |
+
def __init__(self, predict_net, init_net):
|
| 183 |
+
super().__init__()
|
| 184 |
+
self.eval() # always in eval mode
|
| 185 |
+
self._predict_net = predict_net
|
| 186 |
+
self._init_net = init_net
|
| 187 |
+
self._predictor = None
|
| 188 |
+
|
| 189 |
+
@property
|
| 190 |
+
def predict_net(self):
|
| 191 |
+
"""
|
| 192 |
+
Returns:
|
| 193 |
+
core.Net: the underlying caffe2 predict net
|
| 194 |
+
"""
|
| 195 |
+
return self._predict_net
|
| 196 |
+
|
| 197 |
+
@property
|
| 198 |
+
def init_net(self):
|
| 199 |
+
"""
|
| 200 |
+
Returns:
|
| 201 |
+
core.Net: the underlying caffe2 init net
|
| 202 |
+
"""
|
| 203 |
+
return self._init_net
|
| 204 |
+
|
| 205 |
+
__init__.__HIDE_SPHINX_DOC__ = True
|
| 206 |
+
|
| 207 |
+
def save_protobuf(self, output_dir):
|
| 208 |
+
"""
|
| 209 |
+
Save the model as caffe2's protobuf format.
|
| 210 |
+
|
| 211 |
+
Args:
|
| 212 |
+
output_dir (str): the output directory to save protobuf files.
|
| 213 |
+
"""
|
| 214 |
+
logger = logging.getLogger(__name__)
|
| 215 |
+
logger.info("Saving model to {} ...".format(output_dir))
|
| 216 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 217 |
+
|
| 218 |
+
with open(os.path.join(output_dir, "model.pb"), "wb") as f:
|
| 219 |
+
f.write(self._predict_net.SerializeToString())
|
| 220 |
+
with open(os.path.join(output_dir, "model.pbtxt"), "w") as f:
|
| 221 |
+
f.write(str(self._predict_net))
|
| 222 |
+
with open(os.path.join(output_dir, "model_init.pb"), "wb") as f:
|
| 223 |
+
f.write(self._init_net.SerializeToString())
|
| 224 |
+
|
| 225 |
+
def save_graph(self, output_file, inputs=None):
|
| 226 |
+
"""
|
| 227 |
+
Save the graph as SVG format.
|
| 228 |
+
|
| 229 |
+
Args:
|
| 230 |
+
output_file (str): a SVG file
|
| 231 |
+
inputs: optional inputs given to the model.
|
| 232 |
+
If given, the inputs will be used to run the graph to record
|
| 233 |
+
the shape of every tensor. The shape information will be
|
| 234 |
+
saved together with the graph.
|
| 235 |
+
"""
|
| 236 |
+
if inputs is None:
|
| 237 |
+
save_graph(self._predict_net, output_file, op_only=False)
|
| 238 |
+
else:
|
| 239 |
+
size_divisibility = get_pb_arg_vali(self._predict_net, "size_divisibility", 0)
|
| 240 |
+
device = get_pb_arg_vals(self._predict_net, "device", b"cpu").decode("ascii")
|
| 241 |
+
inputs = convert_batched_inputs_to_c2_format(inputs, size_divisibility, device)
|
| 242 |
+
inputs = [x.cpu().numpy() for x in inputs]
|
| 243 |
+
run_and_save_graph(self._predict_net, self._init_net, inputs, output_file)
|
| 244 |
+
|
| 245 |
+
@staticmethod
|
| 246 |
+
def load_protobuf(dir):
|
| 247 |
+
"""
|
| 248 |
+
Args:
|
| 249 |
+
dir (str): a directory used to save Caffe2Model with
|
| 250 |
+
:meth:`save_protobuf`.
|
| 251 |
+
The files "model.pb" and "model_init.pb" are needed.
|
| 252 |
+
|
| 253 |
+
Returns:
|
| 254 |
+
Caffe2Model: the caffe2 model loaded from this directory.
|
| 255 |
+
"""
|
| 256 |
+
predict_net = caffe2_pb2.NetDef()
|
| 257 |
+
with open(os.path.join(dir, "model.pb"), "rb") as f:
|
| 258 |
+
predict_net.ParseFromString(f.read())
|
| 259 |
+
|
| 260 |
+
init_net = caffe2_pb2.NetDef()
|
| 261 |
+
with open(os.path.join(dir, "model_init.pb"), "rb") as f:
|
| 262 |
+
init_net.ParseFromString(f.read())
|
| 263 |
+
|
| 264 |
+
return Caffe2Model(predict_net, init_net)
|
| 265 |
+
|
| 266 |
+
def __call__(self, inputs):
|
| 267 |
+
"""
|
| 268 |
+
An interface that wraps around a caffe2 model and mimics detectron2's models'
|
| 269 |
+
input & output format. This is used to compare the outputs of the caffe2 model
|
| 270 |
+
with its original torch model.
|
| 271 |
+
|
| 272 |
+
Due to the extra conversion between torch/caffe2,
|
| 273 |
+
this method is not meant for benchmarking.
|
| 274 |
+
"""
|
| 275 |
+
if self._predictor is None:
|
| 276 |
+
self._predictor = ProtobufDetectionModel(self._predict_net, self._init_net)
|
| 277 |
+
return self._predictor(inputs)
|
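The export API above is typically driven end to end. The following is a minimal usage sketch, not part of the diff, assuming detectron2 is installed and a config/checkpoint exist; the config and weight paths are hypothetical placeholders, while `get_cfg`, `build_model`, and `DetectionCheckpointer` are detectron2's own helpers.

import torch
from detectron2.config import get_cfg
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.modeling import build_model

cfg = get_cfg()
cfg.merge_from_file("config.yaml")              # hypothetical config path
cfg = add_export_config(cfg)                    # adds the EXPORT_CAFFE2 options
model = build_model(cfg)
DetectionCheckpointer(model).load("model.pth")  # hypothetical weights path
model.eval()

# sample inputs in detectron2's standard batched format
inputs = [{"image": torch.zeros(3, 480, 640)}]
tracer = Caffe2Tracer(cfg, model, inputs)
caffe2_model = tracer.export_caffe2()           # returns a Caffe2Model
caffe2_model.save_protobuf("./caffe2_model")

# later: reload the protobufs and run them to compare against the torch model
reloaded = Caffe2Model.load_protobuf("./caffe2_model")
outputs = reloaded(inputs)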
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/export/c10.py
ADDED
|
@@ -0,0 +1,503 @@
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from detectron2.layers import cat
|
| 8 |
+
from detectron2.layers.roi_align_rotated import ROIAlignRotated
|
| 9 |
+
from detectron2.modeling import poolers
|
| 10 |
+
from detectron2.modeling.proposal_generator import rpn
|
| 11 |
+
from detectron2.modeling.roi_heads.mask_head import mask_rcnn_inference
|
| 12 |
+
from detectron2.structures import Boxes, ImageList, Instances, Keypoints
|
| 13 |
+
|
| 14 |
+
from .shared import alias, to_device
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
"""
|
| 18 |
+
This file contains caffe2-compatible implementations of several detectron2 components.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Caffe2Boxes(Boxes):
|
| 23 |
+
"""
|
| 24 |
+
Representing a list of detectron2.structures.Boxes from minibatch, each box
|
| 25 |
+
is represented by a 5d vector (batch index + 4 coordinates), or a 6d vector
|
| 26 |
+
(batch index + 5 coordinates) for RotatedBoxes.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(self, tensor):
|
| 30 |
+
assert isinstance(tensor, torch.Tensor)
|
| 31 |
+
assert tensor.dim() == 2 and tensor.size(-1) in [4, 5, 6], tensor.size()
|
| 32 |
+
# TODO: make tensor immutable when dim is Nx5 for Boxes,
|
| 33 |
+
# and Nx6 for RotatedBoxes?
|
| 34 |
+
self.tensor = tensor
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# TODO clean up this class, maybe just extend Instances
|
| 38 |
+
class InstancesList(object):
|
| 39 |
+
"""
|
| 40 |
+
Tensor representation of a list of Instances object for a batch of images.
|
| 41 |
+
|
| 42 |
+
When dealing with a batch of images with Caffe2 ops, a list of bboxes
|
| 43 |
+
(instances) is usually represented by a single Tensor with size
|
| 44 |
+
(sigma(Ni), 5) or (sigma(Ni), 4) plus a batch split Tensor. This class is
|
| 45 |
+
for providing common functions to convert between these two representations.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
def __init__(self, im_info, indices, extra_fields=None):
|
| 49 |
+
# [N, 3] -> (H, W, Scale)
|
| 50 |
+
self.im_info = im_info
|
| 51 |
+
# [N,] -> indice of batch to which the instance belongs
|
| 52 |
+
self.indices = indices
|
| 53 |
+
# [N, ...]
|
| 54 |
+
self.batch_extra_fields = extra_fields or {}
|
| 55 |
+
|
| 56 |
+
self.image_size = self.im_info
|
| 57 |
+
|
| 58 |
+
def get_fields(self):
|
| 59 |
+
""" like `get_fields` in the Instances object,
|
| 60 |
+
but returns each field in its tensor representation """
|
| 61 |
+
ret = {}
|
| 62 |
+
for k, v in self.batch_extra_fields.items():
|
| 63 |
+
# if isinstance(v, torch.Tensor):
|
| 64 |
+
# tensor_rep = v
|
| 65 |
+
# elif isinstance(v, (Boxes, Keypoints)):
|
| 66 |
+
# tensor_rep = v.tensor
|
| 67 |
+
# else:
|
| 68 |
+
# raise ValueError("Can't find tensor representation for: {}".format())
|
| 69 |
+
ret[k] = v
|
| 70 |
+
return ret
|
| 71 |
+
|
| 72 |
+
def has(self, name):
|
| 73 |
+
return name in self.batch_extra_fields
|
| 74 |
+
|
| 75 |
+
def set(self, name, value):
|
| 76 |
+
data_len = len(value)
|
| 77 |
+
if len(self.batch_extra_fields):
|
| 78 |
+
assert (
|
| 79 |
+
len(self) == data_len
|
| 80 |
+
), "Adding a field of length {} to a Instances of length {}".format(data_len, len(self))
|
| 81 |
+
self.batch_extra_fields[name] = value
|
| 82 |
+
|
| 83 |
+
def __setattr__(self, name, val):
|
| 84 |
+
if name in ["im_info", "indices", "batch_extra_fields", "image_size"]:
|
| 85 |
+
super().__setattr__(name, val)
|
| 86 |
+
else:
|
| 87 |
+
self.set(name, val)
|
| 88 |
+
|
| 89 |
+
def __getattr__(self, name):
|
| 90 |
+
if name not in self.batch_extra_fields:
|
| 91 |
+
raise AttributeError("Cannot find field '{}' in the given Instances!".format(name))
|
| 92 |
+
return self.batch_extra_fields[name]
|
| 93 |
+
|
| 94 |
+
def __len__(self):
|
| 95 |
+
return len(self.indices)
|
| 96 |
+
|
| 97 |
+
def flatten(self):
|
| 98 |
+
ret = []
|
| 99 |
+
for _, v in self.batch_extra_fields.items():
|
| 100 |
+
if isinstance(v, (Boxes, Keypoints)):
|
| 101 |
+
ret.append(v.tensor)
|
| 102 |
+
else:
|
| 103 |
+
ret.append(v)
|
| 104 |
+
return ret
|
| 105 |
+
|
| 106 |
+
@staticmethod
|
| 107 |
+
def to_d2_instances_list(instances_list):
|
| 108 |
+
"""
|
| 109 |
+
Convert InstancesList to List[Instances]. The input `instances_list` can
|
| 110 |
+
also be a List[Instances]; in this case this method is a no-op.
|
| 111 |
+
"""
|
| 112 |
+
if not isinstance(instances_list, InstancesList):
|
| 113 |
+
assert all(isinstance(x, Instances) for x in instances_list)
|
| 114 |
+
return instances_list
|
| 115 |
+
|
| 116 |
+
ret = []
|
| 117 |
+
for i, info in enumerate(instances_list.im_info):
|
| 118 |
+
instances = Instances(torch.Size([int(info[0].item()), int(info[1].item())]))
|
| 119 |
+
|
| 120 |
+
ids = instances_list.indices == i
|
| 121 |
+
for k, v in instances_list.batch_extra_fields.items():
|
| 122 |
+
if isinstance(v, torch.Tensor):
|
| 123 |
+
instances.set(k, v[ids])
|
| 124 |
+
continue
|
| 125 |
+
elif isinstance(v, Boxes):
|
| 126 |
+
instances.set(k, v[ids, -4:])
|
| 127 |
+
continue
|
| 128 |
+
|
| 129 |
+
target_type, tensor_source = v
|
| 130 |
+
assert isinstance(tensor_source, torch.Tensor)
|
| 131 |
+
assert tensor_source.shape[0] == instances_list.indices.shape[0]
|
| 132 |
+
tensor_source = tensor_source[ids]
|
| 133 |
+
|
| 134 |
+
if issubclass(target_type, Boxes):
|
| 135 |
+
instances.set(k, Boxes(tensor_source[:, -4:]))
|
| 136 |
+
elif issubclass(target_type, Keypoints):
|
| 137 |
+
instances.set(k, Keypoints(tensor_source))
|
| 138 |
+
elif issubclass(target_type, torch.Tensor):
|
| 139 |
+
instances.set(k, tensor_source)
|
| 140 |
+
else:
|
| 141 |
+
raise ValueError("Can't handle targe type: {}".format(target_type))
|
| 142 |
+
|
| 143 |
+
ret.append(instances)
|
| 144 |
+
return ret
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class Caffe2Compatible(object):
|
| 148 |
+
def _get_tensor_mode(self):
|
| 149 |
+
return self._tensor_mode
|
| 150 |
+
|
| 151 |
+
def _set_tensor_mode(self, v):
|
| 152 |
+
self._tensor_mode = v
|
| 153 |
+
|
| 154 |
+
tensor_mode = property(_get_tensor_mode, _set_tensor_mode)
|
| 155 |
+
"""
|
| 156 |
+
If true, the model expects C2-style tensor-only input/output format.
|
| 157 |
+
"""
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class Caffe2RPN(Caffe2Compatible, rpn.RPN):
|
| 161 |
+
def forward(self, images, features, gt_instances=None):
|
| 162 |
+
assert not self.training
|
| 163 |
+
|
| 164 |
+
features = [features[f] for f in self.in_features]
|
| 165 |
+
objectness_logits_pred, anchor_deltas_pred = self.rpn_head(features)
|
| 166 |
+
|
| 167 |
+
assert isinstance(images, ImageList)
|
| 168 |
+
if self.tensor_mode:
|
| 169 |
+
im_info = images.image_sizes
|
| 170 |
+
else:
|
| 171 |
+
im_info = torch.Tensor(
|
| 172 |
+
[[im_sz[0], im_sz[1], torch.Tensor([1.0])] for im_sz in images.image_sizes]
|
| 173 |
+
).to(images.tensor.device)
|
| 174 |
+
assert isinstance(im_info, torch.Tensor)
|
| 175 |
+
|
| 176 |
+
rpn_rois_list = []
|
| 177 |
+
rpn_roi_probs_list = []
|
| 178 |
+
for scores, bbox_deltas, cell_anchors_tensor, feat_stride in zip(
|
| 179 |
+
objectness_logits_pred,
|
| 180 |
+
anchor_deltas_pred,
|
| 181 |
+
iter(self.anchor_generator.cell_anchors),
|
| 182 |
+
self.anchor_generator.strides,
|
| 183 |
+
):
|
| 184 |
+
scores = scores.detach()
|
| 185 |
+
bbox_deltas = bbox_deltas.detach()
|
| 186 |
+
|
| 187 |
+
rpn_rois, rpn_roi_probs = torch.ops._caffe2.GenerateProposals(
|
| 188 |
+
scores,
|
| 189 |
+
bbox_deltas,
|
| 190 |
+
im_info,
|
| 191 |
+
cell_anchors_tensor,
|
| 192 |
+
spatial_scale=1.0 / feat_stride,
|
| 193 |
+
pre_nms_topN=self.pre_nms_topk[self.training],
|
| 194 |
+
post_nms_topN=self.post_nms_topk[self.training],
|
| 195 |
+
nms_thresh=self.nms_thresh,
|
| 196 |
+
min_size=self.min_box_side_len,
|
| 197 |
+
# correct_transform_coords=True, # deprecated argument
|
| 198 |
+
angle_bound_on=True, # Default
|
| 199 |
+
angle_bound_lo=-180,
|
| 200 |
+
angle_bound_hi=180,
|
| 201 |
+
clip_angle_thresh=1.0, # Default
|
| 202 |
+
legacy_plus_one=False,
|
| 203 |
+
)
|
| 204 |
+
rpn_rois_list.append(rpn_rois)
|
| 205 |
+
rpn_roi_probs_list.append(rpn_roi_probs)
|
| 206 |
+
|
| 207 |
+
# For FPN in D2, in RPN all proposals from different levels are concatenated
|
| 208 |
+
# together, ranked and picked by top post_nms_topk. Then in ROIPooler
|
| 209 |
+
# it calculates level_assignments and calls the RoIAlign from
|
| 210 |
+
# the corresponding level.
|
| 211 |
+
|
| 212 |
+
if len(objectness_logits_pred) == 1:
|
| 213 |
+
rpn_rois = rpn_rois_list[0]
|
| 214 |
+
rpn_roi_probs = rpn_roi_probs_list[0]
|
| 215 |
+
else:
|
| 216 |
+
assert len(rpn_rois_list) == len(rpn_roi_probs_list)
|
| 217 |
+
rpn_post_nms_topN = self.post_nms_topk[self.training]
|
| 218 |
+
|
| 219 |
+
device = rpn_rois_list[0].device
|
| 220 |
+
input_list = [to_device(x, "cpu") for x in (rpn_rois_list + rpn_roi_probs_list)]
|
| 221 |
+
|
| 222 |
+
# TODO remove this after confirming rpn_max_level/rpn_min_level
|
| 223 |
+
# is not needed in CollectRpnProposals.
|
| 224 |
+
feature_strides = list(self.anchor_generator.strides)
|
| 225 |
+
rpn_min_level = int(math.log2(feature_strides[0]))
|
| 226 |
+
rpn_max_level = int(math.log2(feature_strides[-1]))
|
| 227 |
+
assert (rpn_max_level - rpn_min_level + 1) == len(
|
| 228 |
+
rpn_rois_list
|
| 229 |
+
), "CollectRpnProposals requires continuous levels"
|
| 230 |
+
|
| 231 |
+
rpn_rois = torch.ops._caffe2.CollectRpnProposals(
|
| 232 |
+
input_list,
|
| 233 |
+
# NOTE: in current implementation, rpn_max_level and rpn_min_level
|
| 234 |
+
# are not needed, only the subtraction of two matters and it
|
| 235 |
+
# can be inferred from the number of inputs. Keep them now for
|
| 236 |
+
# consistency.
|
| 237 |
+
rpn_max_level=2 + len(rpn_rois_list) - 1,
|
| 238 |
+
rpn_min_level=2,
|
| 239 |
+
rpn_post_nms_topN=rpn_post_nms_topN,
|
| 240 |
+
)
|
| 241 |
+
rpn_rois = to_device(rpn_rois, device)
|
| 242 |
+
rpn_roi_probs = []
|
| 243 |
+
|
| 244 |
+
proposals = self.c2_postprocess(im_info, rpn_rois, rpn_roi_probs, self.tensor_mode)
|
| 245 |
+
return proposals, {}
|
| 246 |
+
|
| 247 |
+
@staticmethod
|
| 248 |
+
def c2_postprocess(im_info, rpn_rois, rpn_roi_probs, tensor_mode):
|
| 249 |
+
proposals = InstancesList(
|
| 250 |
+
im_info=im_info,
|
| 251 |
+
indices=rpn_rois[:, 0],
|
| 252 |
+
extra_fields={
|
| 253 |
+
"proposal_boxes": Caffe2Boxes(rpn_rois),
|
| 254 |
+
"objectness_logits": (torch.Tensor, rpn_roi_probs),
|
| 255 |
+
},
|
| 256 |
+
)
|
| 257 |
+
if not tensor_mode:
|
| 258 |
+
proposals = InstancesList.to_d2_instances_list(proposals)
|
| 259 |
+
else:
|
| 260 |
+
proposals = [proposals]
|
| 261 |
+
return proposals
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
class Caffe2ROIPooler(Caffe2Compatible, poolers.ROIPooler):
|
| 265 |
+
@staticmethod
|
| 266 |
+
def c2_preprocess(box_lists):
|
| 267 |
+
assert all(isinstance(x, Boxes) for x in box_lists)
|
| 268 |
+
if all(isinstance(x, Caffe2Boxes) for x in box_lists):
|
| 269 |
+
# input is pure-tensor based
|
| 270 |
+
assert len(box_lists) == 1
|
| 271 |
+
pooler_fmt_boxes = box_lists[0].tensor
|
| 272 |
+
else:
|
| 273 |
+
pooler_fmt_boxes = poolers.convert_boxes_to_pooler_format(box_lists)
|
| 274 |
+
return pooler_fmt_boxes
|
| 275 |
+
|
| 276 |
+
def forward(self, x, box_lists):
|
| 277 |
+
assert not self.training
|
| 278 |
+
|
| 279 |
+
pooler_fmt_boxes = self.c2_preprocess(box_lists)
|
| 280 |
+
num_level_assignments = len(self.level_poolers)
|
| 281 |
+
|
| 282 |
+
if num_level_assignments == 1:
|
| 283 |
+
if isinstance(self.level_poolers[0], ROIAlignRotated):
|
| 284 |
+
c2_roi_align = torch.ops._caffe2.RoIAlignRotated
|
| 285 |
+
aligned = True
|
| 286 |
+
else:
|
| 287 |
+
c2_roi_align = torch.ops._caffe2.RoIAlign
|
| 288 |
+
aligned = self.level_poolers[0].aligned
|
| 289 |
+
|
| 290 |
+
out = c2_roi_align(
|
| 291 |
+
x[0],
|
| 292 |
+
pooler_fmt_boxes,
|
| 293 |
+
order="NCHW",
|
| 294 |
+
spatial_scale=float(self.level_poolers[0].spatial_scale),
|
| 295 |
+
pooled_h=int(self.output_size[0]),
|
| 296 |
+
pooled_w=int(self.output_size[1]),
|
| 297 |
+
sampling_ratio=int(self.level_poolers[0].sampling_ratio),
|
| 298 |
+
aligned=aligned,
|
| 299 |
+
)
|
| 300 |
+
return out
|
| 301 |
+
|
| 302 |
+
device = pooler_fmt_boxes.device
|
| 303 |
+
assert (
|
| 304 |
+
self.max_level - self.min_level + 1 == 4
|
| 305 |
+
), "Currently DistributeFpnProposals only support 4 levels"
|
| 306 |
+
fpn_outputs = torch.ops._caffe2.DistributeFpnProposals(
|
| 307 |
+
to_device(pooler_fmt_boxes, "cpu"),
|
| 308 |
+
roi_canonical_scale=self.canonical_box_size,
|
| 309 |
+
roi_canonical_level=self.canonical_level,
|
| 310 |
+
roi_max_level=self.max_level,
|
| 311 |
+
roi_min_level=self.min_level,
|
| 312 |
+
legacy_plus_one=False,
|
| 313 |
+
)
|
| 314 |
+
fpn_outputs = [to_device(x, device) for x in fpn_outputs]
|
| 315 |
+
|
| 316 |
+
rois_fpn_list = fpn_outputs[:-1]
|
| 317 |
+
rois_idx_restore_int32 = fpn_outputs[-1]
|
| 318 |
+
|
| 319 |
+
roi_feat_fpn_list = []
|
| 320 |
+
for roi_fpn, x_level, pooler in zip(rois_fpn_list, x, self.level_poolers):
|
| 321 |
+
if isinstance(pooler, ROIAlignRotated):
|
| 322 |
+
c2_roi_align = torch.ops._caffe2.RoIAlignRotated
|
| 323 |
+
aligned = True
|
| 324 |
+
else:
|
| 325 |
+
c2_roi_align = torch.ops._caffe2.RoIAlign
|
| 326 |
+
aligned = bool(pooler.aligned)
|
| 327 |
+
|
| 328 |
+
roi_feat_fpn = c2_roi_align(
|
| 329 |
+
x_level,
|
| 330 |
+
roi_fpn,
|
| 331 |
+
order="NCHW",
|
| 332 |
+
spatial_scale=float(pooler.spatial_scale),
|
| 333 |
+
pooled_h=int(self.output_size[0]),
|
| 334 |
+
pooled_w=int(self.output_size[1]),
|
| 335 |
+
sampling_ratio=int(pooler.sampling_ratio),
|
| 336 |
+
aligned=aligned,
|
| 337 |
+
)
|
| 338 |
+
roi_feat_fpn_list.append(roi_feat_fpn)
|
| 339 |
+
|
| 340 |
+
roi_feat_shuffled = cat(roi_feat_fpn_list, dim=0)
|
| 341 |
+
roi_feat = torch.ops._caffe2.BatchPermutation(roi_feat_shuffled, rois_idx_restore_int32)
|
| 342 |
+
return roi_feat
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
class Caffe2FastRCNNOutputsInference:
|
| 346 |
+
def __init__(self, tensor_mode):
|
| 347 |
+
self.tensor_mode = tensor_mode # whether the output is caffe2 tensor mode
|
| 348 |
+
|
| 349 |
+
def __call__(self, box_predictor, predictions, proposals):
|
| 350 |
+
""" equivalent to FastRCNNOutputLayers.inference """
|
| 351 |
+
score_thresh = box_predictor.test_score_thresh
|
| 352 |
+
nms_thresh = box_predictor.test_nms_thresh
|
| 353 |
+
topk_per_image = box_predictor.test_topk_per_image
|
| 354 |
+
is_rotated = len(box_predictor.box2box_transform.weights) == 5
|
| 355 |
+
|
| 356 |
+
if is_rotated:
|
| 357 |
+
box_dim = 5
|
| 358 |
+
assert box_predictor.box2box_transform.weights[4] == 1, (
|
| 359 |
+
"The weights for Rotated BBoxTransform in C2 have only 4 dimensions,"
|
| 360 |
+
+ " thus enforcing the angle weight to be 1 for now"
|
| 361 |
+
)
|
| 362 |
+
box2box_transform_weights = box_predictor.box2box_transform.weights[:4]
|
| 363 |
+
else:
|
| 364 |
+
box_dim = 4
|
| 365 |
+
box2box_transform_weights = box_predictor.box2box_transform.weights
|
| 366 |
+
|
| 367 |
+
class_logits, box_regression = predictions
|
| 368 |
+
class_prob = F.softmax(class_logits, -1)
|
| 369 |
+
|
| 370 |
+
assert box_regression.shape[1] % box_dim == 0
|
| 371 |
+
cls_agnostic_bbox_reg = box_regression.shape[1] // box_dim == 1
|
| 372 |
+
|
| 373 |
+
input_tensor_mode = proposals[0].proposal_boxes.tensor.shape[1] == box_dim + 1
|
| 374 |
+
|
| 375 |
+
rois = type(proposals[0].proposal_boxes).cat([p.proposal_boxes for p in proposals])
|
| 376 |
+
device, dtype = rois.tensor.device, rois.tensor.dtype
|
| 377 |
+
if input_tensor_mode:
|
| 378 |
+
im_info = proposals[0].image_size
|
| 379 |
+
rois = rois.tensor
|
| 380 |
+
else:
|
| 381 |
+
im_info = torch.Tensor(
|
| 382 |
+
[[sz[0], sz[1], 1.0] for sz in [x.image_size for x in proposals]]
|
| 383 |
+
)
|
| 384 |
+
batch_ids = cat(
|
| 385 |
+
[
|
| 386 |
+
torch.full((b, 1), i, dtype=dtype, device=device)
|
| 387 |
+
for i, b in enumerate(len(p) for p in proposals)
|
| 388 |
+
],
|
| 389 |
+
dim=0,
|
| 390 |
+
)
|
| 391 |
+
rois = torch.cat([batch_ids, rois.tensor], dim=1)
|
| 392 |
+
|
| 393 |
+
roi_pred_bbox, roi_batch_splits = torch.ops._caffe2.BBoxTransform(
|
| 394 |
+
to_device(rois, "cpu"),
|
| 395 |
+
to_device(box_regression, "cpu"),
|
| 396 |
+
to_device(im_info, "cpu"),
|
| 397 |
+
weights=box2box_transform_weights,
|
| 398 |
+
apply_scale=True,
|
| 399 |
+
rotated=is_rotated,
|
| 400 |
+
angle_bound_on=True,
|
| 401 |
+
angle_bound_lo=-180,
|
| 402 |
+
angle_bound_hi=180,
|
| 403 |
+
clip_angle_thresh=1.0,
|
| 404 |
+
legacy_plus_one=False,
|
| 405 |
+
)
|
| 406 |
+
roi_pred_bbox = to_device(roi_pred_bbox, device)
|
| 407 |
+
roi_batch_splits = to_device(roi_batch_splits, device)
|
| 408 |
+
|
| 409 |
+
nms_outputs = torch.ops._caffe2.BoxWithNMSLimit(
|
| 410 |
+
to_device(class_prob, "cpu"),
|
| 411 |
+
to_device(roi_pred_bbox, "cpu"),
|
| 412 |
+
to_device(roi_batch_splits, "cpu"),
|
| 413 |
+
score_thresh=float(score_thresh),
|
| 414 |
+
nms=float(nms_thresh),
|
| 415 |
+
detections_per_im=int(topk_per_image),
|
| 416 |
+
soft_nms_enabled=False,
|
| 417 |
+
soft_nms_method="linear",
|
| 418 |
+
soft_nms_sigma=0.5,
|
| 419 |
+
soft_nms_min_score_thres=0.001,
|
| 420 |
+
rotated=is_rotated,
|
| 421 |
+
cls_agnostic_bbox_reg=cls_agnostic_bbox_reg,
|
| 422 |
+
input_boxes_include_bg_cls=False,
|
| 423 |
+
output_classes_include_bg_cls=False,
|
| 424 |
+
legacy_plus_one=False,
|
| 425 |
+
)
|
| 426 |
+
roi_score_nms = to_device(nms_outputs[0], device)
|
| 427 |
+
roi_bbox_nms = to_device(nms_outputs[1], device)
|
| 428 |
+
roi_class_nms = to_device(nms_outputs[2], device)
|
| 429 |
+
roi_batch_splits_nms = to_device(nms_outputs[3], device)
|
| 430 |
+
roi_keeps_nms = to_device(nms_outputs[4], device)
|
| 431 |
+
roi_keeps_size_nms = to_device(nms_outputs[5], device)
|
| 432 |
+
if not self.tensor_mode:
|
| 433 |
+
roi_class_nms = roi_class_nms.to(torch.int64)
|
| 434 |
+
|
| 435 |
+
roi_batch_ids = cat(
|
| 436 |
+
[
|
| 437 |
+
torch.full((b, 1), i, dtype=dtype, device=device)
|
| 438 |
+
for i, b in enumerate(int(x.item()) for x in roi_batch_splits_nms)
|
| 439 |
+
],
|
| 440 |
+
dim=0,
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
roi_class_nms = alias(roi_class_nms, "class_nms")
|
| 444 |
+
roi_score_nms = alias(roi_score_nms, "score_nms")
|
| 445 |
+
roi_bbox_nms = alias(roi_bbox_nms, "bbox_nms")
|
| 446 |
+
roi_batch_splits_nms = alias(roi_batch_splits_nms, "batch_splits_nms")
|
| 447 |
+
roi_keeps_nms = alias(roi_keeps_nms, "keeps_nms")
|
| 448 |
+
roi_keeps_size_nms = alias(roi_keeps_size_nms, "keeps_size_nms")
|
| 449 |
+
|
| 450 |
+
results = InstancesList(
|
| 451 |
+
im_info=im_info,
|
| 452 |
+
indices=roi_batch_ids[:, 0],
|
| 453 |
+
extra_fields={
|
| 454 |
+
"pred_boxes": Caffe2Boxes(roi_bbox_nms),
|
| 455 |
+
"scores": roi_score_nms,
|
| 456 |
+
"pred_classes": roi_class_nms,
|
| 457 |
+
},
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
if not self.tensor_mode:
|
| 461 |
+
results = InstancesList.to_d2_instances_list(results)
|
| 462 |
+
batch_splits = roi_batch_splits_nms.int().tolist()
|
| 463 |
+
kept_indices = list(roi_keeps_nms.to(torch.int64).split(batch_splits))
|
| 464 |
+
else:
|
| 465 |
+
results = [results]
|
| 466 |
+
kept_indices = [roi_keeps_nms]
|
| 467 |
+
|
| 468 |
+
return results, kept_indices
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
class Caffe2MaskRCNNInference:
|
| 472 |
+
def __call__(self, pred_mask_logits, pred_instances):
|
| 473 |
+
""" equivalent to mask_head.mask_rcnn_inference """
|
| 474 |
+
if all(isinstance(x, InstancesList) for x in pred_instances):
|
| 475 |
+
assert len(pred_instances) == 1
|
| 476 |
+
mask_probs_pred = pred_mask_logits.sigmoid()
|
| 477 |
+
mask_probs_pred = alias(mask_probs_pred, "mask_fcn_probs")
|
| 478 |
+
pred_instances[0].pred_masks = mask_probs_pred
|
| 479 |
+
else:
|
| 480 |
+
mask_rcnn_inference(pred_mask_logits, pred_instances)
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
class Caffe2KeypointRCNNInference:
|
| 484 |
+
def __init__(self, use_heatmap_max_keypoint):
|
| 485 |
+
self.use_heatmap_max_keypoint = use_heatmap_max_keypoint
|
| 486 |
+
|
| 487 |
+
def __call__(self, pred_keypoint_logits, pred_instances):
|
| 488 |
+
# just return the keypoint heatmap for now,
|
| 489 |
+
# there will be an option to call HeatmapMaxKeypointOp
|
| 490 |
+
output = alias(pred_keypoint_logits, "kps_score")
|
| 491 |
+
if all(isinstance(x, InstancesList) for x in pred_instances):
|
| 492 |
+
assert len(pred_instances) == 1
|
| 493 |
+
if self.use_heatmap_max_keypoint:
|
| 494 |
+
device = output.device
|
| 495 |
+
output = torch.ops._caffe2.HeatmapMaxKeypoint(
|
| 496 |
+
to_device(output, "cpu"),
|
| 497 |
+
pred_instances[0].pred_boxes.tensor,
|
| 498 |
+
should_output_softmax=True,  # worth making it configurable?
|
| 499 |
+
)
|
| 500 |
+
output = to_device(output, device)
|
| 501 |
+
output = alias(output, "keypoints_out")
|
| 502 |
+
pred_instances[0].pred_keypoints = output
|
| 503 |
+
return pred_keypoint_logits
|
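As a small illustration of the "single tensor plus batch splits" layout that Caffe2Boxes and InstancesList describe above, the sketch below (plain torch, with made-up values) flattens boxes from two images into one (sum(Ni), 5) tensor with the batch index in column 0:

import torch

boxes_img0 = torch.tensor([[10., 10., 50., 50.], [20., 30., 60., 80.]])  # 2 boxes
boxes_img1 = torch.tensor([[5., 15., 25., 45.]])                         # 1 box

batch_ids = torch.tensor([[0.], [0.], [1.]])   # which image each box came from
pooler_fmt = torch.cat([batch_ids, torch.cat([boxes_img0, boxes_img1])], dim=1)
batch_splits = torch.tensor([2, 1])            # Ni per image

print(pooler_fmt.shape)    # torch.Size([3, 5])
print(batch_splits.sum())  # tensor(3), total number of boxes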
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/export/caffe2_export.py
ADDED
|
@@ -0,0 +1,204 @@
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
import copy
|
| 4 |
+
import io
|
| 5 |
+
import logging
|
| 6 |
+
import numpy as np
|
| 7 |
+
from typing import List
|
| 8 |
+
import onnx
|
| 9 |
+
import torch
|
| 10 |
+
from caffe2.proto import caffe2_pb2
|
| 11 |
+
from caffe2.python import core
|
| 12 |
+
from caffe2.python.onnx.backend import Caffe2Backend
|
| 13 |
+
from tabulate import tabulate
|
| 14 |
+
from termcolor import colored
|
| 15 |
+
from torch.onnx import OperatorExportTypes
|
| 16 |
+
|
| 17 |
+
from .shared import (
|
| 18 |
+
ScopedWS,
|
| 19 |
+
construct_init_net_from_params,
|
| 20 |
+
fuse_alias_placeholder,
|
| 21 |
+
fuse_copy_between_cpu_and_gpu,
|
| 22 |
+
get_params_from_init_net,
|
| 23 |
+
group_norm_replace_aten_with_caffe2,
|
| 24 |
+
infer_device_type,
|
| 25 |
+
remove_dead_end_ops,
|
| 26 |
+
remove_reshape_for_fc,
|
| 27 |
+
save_graph,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
logger = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def export_onnx_model(model, inputs):
|
| 34 |
+
"""
|
| 35 |
+
Trace and export a model to onnx format.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
model (nn.Module):
|
| 39 |
+
inputs (tuple[args]): the model will be called by `model(*inputs)`
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
an onnx model
|
| 43 |
+
"""
|
| 44 |
+
assert isinstance(model, torch.nn.Module)
|
| 45 |
+
|
| 46 |
+
# make sure all modules are in eval mode, onnx may change the training state
|
| 47 |
+
# of the module if the states are not consistent
|
| 48 |
+
def _check_eval(module):
|
| 49 |
+
assert not module.training
|
| 50 |
+
|
| 51 |
+
model.apply(_check_eval)
|
| 52 |
+
|
| 53 |
+
# Export the model to ONNX
|
| 54 |
+
with torch.no_grad():
|
| 55 |
+
with io.BytesIO() as f:
|
| 56 |
+
torch.onnx.export(
|
| 57 |
+
model,
|
| 58 |
+
inputs,
|
| 59 |
+
f,
|
| 60 |
+
operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
|
| 61 |
+
# verbose=True, # NOTE: uncomment this for debugging
|
| 62 |
+
# export_params=True,
|
| 63 |
+
)
|
| 64 |
+
onnx_model = onnx.load_from_string(f.getvalue())
|
| 65 |
+
|
| 66 |
+
# Apply ONNX's Optimization
|
| 67 |
+
all_passes = onnx.optimizer.get_available_passes()
|
| 68 |
+
passes = ["fuse_bn_into_conv"]
|
| 69 |
+
assert all(p in all_passes for p in passes)
|
| 70 |
+
onnx_model = onnx.optimizer.optimize(onnx_model, passes)
|
| 71 |
+
return onnx_model
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _op_stats(net_def):
|
| 75 |
+
type_count = {}
|
| 76 |
+
for t in [op.type for op in net_def.op]:
|
| 77 |
+
type_count[t] = type_count.get(t, 0) + 1
|
| 78 |
+
type_count_list = sorted(type_count.items(), key=lambda kv: kv[0]) # alphabet
|
| 79 |
+
type_count_list = sorted(type_count_list, key=lambda kv: -kv[1]) # count
|
| 80 |
+
return "\n".join("{:>4}x {}".format(count, name) for name, count in type_count_list)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _assign_device_option(
|
| 84 |
+
predict_net: caffe2_pb2.NetDef, init_net: caffe2_pb2.NetDef, tensor_inputs: List[torch.Tensor]
|
| 85 |
+
):
|
| 86 |
+
"""
|
| 87 |
+
The ONNX-exported network doesn't have a concept of device; assign the necessary
|
| 88 |
+
device option for each op in order to make it runnable on the GPU runtime.
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
def _get_device_type(torch_tensor):
|
| 92 |
+
assert torch_tensor.device.type in ["cpu", "cuda"]
|
| 93 |
+
assert torch_tensor.device.index == 0
|
| 94 |
+
return torch_tensor.device.type
|
| 95 |
+
|
| 96 |
+
def _assign_op_device_option(net_proto, net_ssa, blob_device_types):
|
| 97 |
+
for op, ssa_i in zip(net_proto.op, net_ssa):
|
| 98 |
+
if op.type in ["CopyCPUToGPU", "CopyGPUToCPU"]:
|
| 99 |
+
op.device_option.CopyFrom(core.DeviceOption(caffe2_pb2.CUDA, 0))
|
| 100 |
+
else:
|
| 101 |
+
devices = [blob_device_types[b] for b in ssa_i[0] + ssa_i[1]]
|
| 102 |
+
assert all(d == devices[0] for d in devices)
|
| 103 |
+
if devices[0] == "cuda":
|
| 104 |
+
op.device_option.CopyFrom(core.DeviceOption(caffe2_pb2.CUDA, 0))
|
| 105 |
+
|
| 106 |
+
# update ops in predict_net
|
| 107 |
+
predict_net_input_device_types = {
|
| 108 |
+
(name, 0): _get_device_type(tensor)
|
| 109 |
+
for name, tensor in zip(predict_net.external_input, tensor_inputs)
|
| 110 |
+
}
|
| 111 |
+
predict_net_device_types = infer_device_type(
|
| 112 |
+
predict_net, known_status=predict_net_input_device_types, device_name_style="pytorch"
|
| 113 |
+
)
|
| 114 |
+
predict_net_ssa, _ = core.get_ssa(predict_net)
|
| 115 |
+
_assign_op_device_option(predict_net, predict_net_ssa, predict_net_device_types)
|
| 116 |
+
|
| 117 |
+
# update ops in init_net
|
| 118 |
+
init_net_ssa, versions = core.get_ssa(init_net)
|
| 119 |
+
init_net_output_device_types = {
|
| 120 |
+
(name, versions[name]): predict_net_device_types[(name, 0)]
|
| 121 |
+
for name in init_net.external_output
|
| 122 |
+
}
|
| 123 |
+
init_net_device_types = infer_device_type(
|
| 124 |
+
init_net, known_status=init_net_output_device_types, device_name_style="pytorch"
|
| 125 |
+
)
|
| 126 |
+
_assign_op_device_option(init_net, init_net_ssa, init_net_device_types)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def export_caffe2_detection_model(model: torch.nn.Module, tensor_inputs: List[torch.Tensor]):
|
| 130 |
+
"""
|
| 131 |
+
Export a caffe2-compatible Detectron2 model to caffe2 format via ONNX.
|
| 132 |
+
|
| 133 |
+
Args:
|
| 134 |
+
model: a caffe2-compatible version of detectron2 model, defined in caffe2_modeling.py
|
| 135 |
+
tensor_inputs: a list of tensors that caffe2 model takes as input.
|
| 136 |
+
"""
|
| 137 |
+
model = copy.deepcopy(model)
|
| 138 |
+
assert isinstance(model, torch.nn.Module)
|
| 139 |
+
assert hasattr(model, "encode_additional_info")
|
| 140 |
+
|
| 141 |
+
# Export via ONNX
|
| 142 |
+
logger.info("Exporting a {} model via ONNX ...".format(type(model).__name__))
|
| 143 |
+
onnx_model = export_onnx_model(model, (tensor_inputs,))
|
| 144 |
+
# Convert ONNX model to Caffe2 protobuf
|
| 145 |
+
init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model)
|
| 146 |
+
ops_table = [[op.type, op.input, op.output] for op in predict_net.op]
|
| 147 |
+
table = tabulate(ops_table, headers=["type", "input", "output"], tablefmt="pipe")
|
| 148 |
+
logger.info(
|
| 149 |
+
"ONNX export Done. Exported predict_net (before optimizations):\n" + colored(table, "cyan")
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
# Apply protobuf optimization
|
| 153 |
+
fuse_alias_placeholder(predict_net, init_net)
|
| 154 |
+
if any(t.device.type != "cpu" for t in tensor_inputs):
|
| 155 |
+
fuse_copy_between_cpu_and_gpu(predict_net)
|
| 156 |
+
remove_dead_end_ops(init_net)
|
| 157 |
+
_assign_device_option(predict_net, init_net, tensor_inputs)
|
| 158 |
+
params, device_options = get_params_from_init_net(init_net)
|
| 159 |
+
predict_net, params = remove_reshape_for_fc(predict_net, params)
|
| 160 |
+
init_net = construct_init_net_from_params(params, device_options)
|
| 161 |
+
group_norm_replace_aten_with_caffe2(predict_net)
|
| 162 |
+
|
| 163 |
+
# Record necessary information for running the pb model in Detectron2 system.
|
| 164 |
+
model.encode_additional_info(predict_net, init_net)
|
| 165 |
+
|
| 166 |
+
logger.info("Operators used in predict_net: \n{}".format(_op_stats(predict_net)))
|
| 167 |
+
logger.info("Operators used in init_net: \n{}".format(_op_stats(init_net)))
|
| 168 |
+
|
| 169 |
+
return predict_net, init_net
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def run_and_save_graph(predict_net, init_net, tensor_inputs, graph_save_path):
|
| 173 |
+
"""
|
| 174 |
+
Run the caffe2 model on the given inputs, record the blob shapes, and draw the graph.
|
| 175 |
+
|
| 176 |
+
predict_net/init_net: caffe2 model.
|
| 177 |
+
tensor_inputs: a list of tensors that caffe2 model takes as input.
|
| 178 |
+
graph_save_path: path for saving graph of exported model.
|
| 179 |
+
"""
|
| 180 |
+
|
| 181 |
+
logger.info("Saving graph of ONNX exported model to {} ...".format(graph_save_path))
|
| 182 |
+
save_graph(predict_net, graph_save_path, op_only=False)
|
| 183 |
+
|
| 184 |
+
# Run the exported Caffe2 net
|
| 185 |
+
logger.info("Running ONNX exported model ...")
|
| 186 |
+
with ScopedWS("__ws_tmp__", True) as ws:
|
| 187 |
+
ws.RunNetOnce(init_net)
|
| 188 |
+
initialized_blobs = set(ws.Blobs())
|
| 189 |
+
uninitialized = [inp for inp in predict_net.external_input if inp not in initialized_blobs]
|
| 190 |
+
for name, blob in zip(uninitialized, tensor_inputs):
|
| 191 |
+
ws.FeedBlob(name, blob)
|
| 192 |
+
|
| 193 |
+
try:
|
| 194 |
+
ws.RunNetOnce(predict_net)
|
| 195 |
+
except RuntimeError as e:
|
| 196 |
+
logger.warning("Encountered RuntimeError: \n{}".format(str(e)))
|
| 197 |
+
|
| 198 |
+
ws_blobs = {b: ws.FetchBlob(b) for b in ws.Blobs()}
|
| 199 |
+
blob_sizes = {b: ws_blobs[b].shape for b in ws_blobs if isinstance(ws_blobs[b], np.ndarray)}
|
| 200 |
+
|
| 201 |
+
logger.info("Saving graph with blob shapes to {} ...".format(graph_save_path))
|
| 202 |
+
save_graph(predict_net, graph_save_path, op_only=False, blob_sizes=blob_sizes)
|
| 203 |
+
|
| 204 |
+
return ws_blobs
|
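The in-memory ONNX export pattern used by export_onnx_model() above can be exercised on a toy module; this sketch is not part of the diff and the TinyNet module is made up, but the torch.onnx.export / onnx.load_from_string calls mirror the ones in the file:

import io
import onnx
import torch
from torch.onnx import OperatorExportTypes


class TinyNet(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)


model = TinyNet().eval()
inputs = (torch.randn(1, 3, 8, 8),)
with torch.no_grad():
    with io.BytesIO() as f:
        torch.onnx.export(
            model,
            inputs,
            f,
            operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
        )
        onnx_model = onnx.load_from_string(f.getvalue())
print([node.op_type for node in onnx_model.graph.node])  # e.g. ['Relu']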
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/export/caffe2_inference.py
ADDED
|
@@ -0,0 +1,136 @@
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
import collections
|
| 4 |
+
import logging
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from caffe2.proto import caffe2_pb2
|
| 8 |
+
from caffe2.python import core
|
| 9 |
+
|
| 10 |
+
from .caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP, convert_batched_inputs_to_c2_format
|
| 11 |
+
from .shared import ScopedWS, get_pb_arg_vali, get_pb_arg_vals, infer_device_type
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ProtobufModel(torch.nn.Module):
|
| 17 |
+
"""
|
| 18 |
+
A class that works just like nn.Module in terms of inference, but runs a
|
| 19 |
+
caffe2 model under the hood. Inputs/outputs are Dict[str, tensor] whose keys
|
| 20 |
+
are in external_input/output.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(self, predict_net, init_net):
|
| 24 |
+
logger.info("Initializing ProtobufModel ...")
|
| 25 |
+
super().__init__()
|
| 26 |
+
assert isinstance(predict_net, caffe2_pb2.NetDef)
|
| 27 |
+
assert isinstance(init_net, caffe2_pb2.NetDef)
|
| 28 |
+
self.ws_name = "__ws_tmp__"
|
| 29 |
+
self.net = core.Net(predict_net)
|
| 30 |
+
|
| 31 |
+
with ScopedWS(self.ws_name, is_reset=True, is_cleanup=False) as ws:
|
| 32 |
+
ws.RunNetOnce(init_net)
|
| 33 |
+
for blob in self.net.Proto().external_input:
|
| 34 |
+
if blob not in ws.Blobs():
|
| 35 |
+
ws.CreateBlob(blob)
|
| 36 |
+
ws.CreateNet(self.net)
|
| 37 |
+
|
| 38 |
+
self._error_msgs = set()
|
| 39 |
+
|
| 40 |
+
def forward(self, inputs_dict):
|
| 41 |
+
assert all(inp in self.net.Proto().external_input for inp in inputs_dict)
|
| 42 |
+
with ScopedWS(self.ws_name, is_reset=False, is_cleanup=False) as ws:
|
| 43 |
+
for b, tensor in inputs_dict.items():
|
| 44 |
+
ws.FeedBlob(b, tensor)
|
| 45 |
+
try:
|
| 46 |
+
ws.RunNet(self.net.Proto().name)
|
| 47 |
+
except RuntimeError as e:
|
| 48 |
+
if str(e) not in self._error_msgs:
|
| 49 |
+
self._error_msgs.add(str(e))
|
| 50 |
+
logger.warning("Encountered new RuntimeError: \n{}".format(str(e)))
|
| 51 |
+
logger.warning("Catch the error and use partial results.")
|
| 52 |
+
|
| 53 |
+
outputs_dict = collections.OrderedDict(
|
| 54 |
+
[(b, ws.FetchBlob(b)) for b in self.net.Proto().external_output]
|
| 55 |
+
)
|
| 56 |
+
# Remove outputs of the current run; this is necessary in order to
|
| 57 |
+
# prevent fetching the result from previous run if the model fails
|
| 58 |
+
# in the middle.
|
| 59 |
+
for b in self.net.Proto().external_output:
|
| 60 |
+
# Needs to create an uninitialized blob to make the net runnable.
|
| 61 |
+
# This is "equivalent" to: ws.RemoveBlob(b) then ws.CreateBlob(b),
|
| 62 |
+
# but there's no such API.
|
| 63 |
+
ws.FeedBlob(b, "{}, a C++ native class of type nullptr (uninitialized).".format(b))
|
| 64 |
+
|
| 65 |
+
return outputs_dict
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class ProtobufDetectionModel(torch.nn.Module):
|
| 69 |
+
"""
|
| 70 |
+
A class that works just like a pytorch meta arch in terms of inference, but runs a
|
| 71 |
+
caffe2 model under the hood.
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
def __init__(self, predict_net, init_net, *, convert_outputs=None):
|
| 75 |
+
"""
|
| 76 |
+
Args:
|
| 77 |
+
predict_net, init_net (core.Net): caffe2 nets
|
| 78 |
+
convert_outputs (callable): a function that converts caffe2
|
| 79 |
+
outputs to the same format of the original pytorch model.
|
| 80 |
+
By default, use the one defined in the caffe2 meta_arch.
|
| 81 |
+
"""
|
| 82 |
+
super().__init__()
|
| 83 |
+
self.protobuf_model = ProtobufModel(predict_net, init_net)
|
| 84 |
+
self.size_divisibility = get_pb_arg_vali(predict_net, "size_divisibility", 0)
|
| 85 |
+
self.device = get_pb_arg_vals(predict_net, "device", b"cpu").decode("ascii")
|
| 86 |
+
|
| 87 |
+
if convert_outputs is None:
|
| 88 |
+
meta_arch = get_pb_arg_vals(predict_net, "meta_architecture", b"GeneralizedRCNN")
|
| 89 |
+
meta_arch = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[meta_arch.decode("ascii")]
|
| 90 |
+
self._convert_outputs = meta_arch.get_outputs_converter(predict_net, init_net)
|
| 91 |
+
else:
|
| 92 |
+
self._convert_outputs = convert_outputs
|
| 93 |
+
|
| 94 |
+
def _infer_output_devices(self, inputs_dict):
|
| 95 |
+
def _get_device_type(torch_tensor):
|
| 96 |
+
assert torch_tensor.device.type in ["cpu", "cuda"]
|
| 97 |
+
assert torch_tensor.device.index == 0
|
| 98 |
+
return torch_tensor.device.type
|
| 99 |
+
|
| 100 |
+
predict_net = self.protobuf_model.net.Proto()
|
| 101 |
+
input_device_types = {
|
| 102 |
+
(name, 0): _get_device_type(tensor) for name, tensor in inputs_dict.items()
|
| 103 |
+
}
|
| 104 |
+
device_type_map = infer_device_type(
|
| 105 |
+
predict_net, known_status=input_device_types, device_name_style="pytorch"
|
| 106 |
+
)
|
| 107 |
+
ssa, versions = core.get_ssa(predict_net)
|
| 108 |
+
versioned_outputs = [(name, versions[name]) for name in predict_net.external_output]
|
| 109 |
+
output_devices = [device_type_map[outp] for outp in versioned_outputs]
|
| 110 |
+
return output_devices
|
| 111 |
+
|
| 112 |
+
def _convert_inputs(self, batched_inputs):
|
| 113 |
+
# currently all models convert inputs in the same way
|
| 114 |
+
data, im_info = convert_batched_inputs_to_c2_format(
|
| 115 |
+
batched_inputs, self.size_divisibility, self.device
|
| 116 |
+
)
|
| 117 |
+
return {"data": data, "im_info": im_info}
|
| 118 |
+
|
| 119 |
+
def forward(self, batched_inputs):
|
| 120 |
+
c2_inputs = self._convert_inputs(batched_inputs)
|
| 121 |
+
c2_results = self.protobuf_model(c2_inputs)
|
| 122 |
+
|
| 123 |
+
if any(t.device.type != "cpu" for _, t in c2_inputs.items()):
|
| 124 |
+
output_devices = self._infer_output_devices(c2_inputs)
|
| 125 |
+
else:
|
| 126 |
+
output_devices = ["cpu" for _ in self.protobuf_model.net.Proto().external_output]
|
| 127 |
+
|
| 128 |
+
def _cast_caffe2_blob_to_torch_tensor(blob, device):
|
| 129 |
+
return torch.Tensor(blob).to(device) if isinstance(blob, np.ndarray) else None
|
| 130 |
+
|
| 131 |
+
c2_results = {
|
| 132 |
+
name: _cast_caffe2_blob_to_torch_tensor(c2_results[name], device)
|
| 133 |
+
for name, device in zip(self.protobuf_model.net.Proto().external_output, output_devices)
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
return self._convert_outputs(batched_inputs, c2_inputs, c2_results)
|
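For reference, ProtobufDetectionModel.forward() consumes the same list-of-dicts batched inputs as detectron2's pytorch models: each entry carries a CHW "image" tensor plus the optional original "height"/"width". A minimal sketch follows (the nets are placeholders; running it requires a previously exported model):

import torch

batched_inputs = [
    {"image": torch.zeros(3, 480, 640), "height": 480, "width": 640},
]
# predict_net / init_net would come from export_caffe2_detection_model():
# model = ProtobufDetectionModel(predict_net, init_net)
# outputs = model(batched_inputs)  # converted back to detectron2-style outputs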
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/export/caffe2_modeling.py
ADDED
|
@@ -0,0 +1,493 @@
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
import functools
|
| 4 |
+
import io
|
| 5 |
+
import struct
|
| 6 |
+
import types
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from detectron2.modeling import meta_arch
|
| 10 |
+
from detectron2.modeling.box_regression import Box2BoxTransform
|
| 11 |
+
from detectron2.modeling.meta_arch.panoptic_fpn import combine_semantic_and_instance_outputs
|
| 12 |
+
from detectron2.modeling.postprocessing import detector_postprocess, sem_seg_postprocess
|
| 13 |
+
from detectron2.modeling.roi_heads import keypoint_head
|
| 14 |
+
from detectron2.structures import Boxes, ImageList, Instances, RotatedBoxes
|
| 15 |
+
|
| 16 |
+
from .c10 import Caffe2Compatible
|
| 17 |
+
from .patcher import ROIHeadsPatcher, patch_generalized_rcnn
|
| 18 |
+
from .shared import (
|
| 19 |
+
alias,
|
| 20 |
+
check_set_pb_arg,
|
| 21 |
+
get_pb_arg_floats,
|
| 22 |
+
get_pb_arg_valf,
|
| 23 |
+
get_pb_arg_vali,
|
| 24 |
+
get_pb_arg_vals,
|
| 25 |
+
mock_torch_nn_functional_interpolate,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def assemble_rcnn_outputs_by_name(image_sizes, tensor_outputs, force_mask_on=False):
|
| 30 |
+
"""
|
| 31 |
+
A function to assemble caffe2 model's outputs (i.e. Dict[str, Tensor])
|
| 32 |
+
to detectron2's format (i.e. a list of Instances instances).
|
| 33 |
+
This only works when the model follows the Caffe2 detectron's naming convention.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
image_sizes (List[List[int, int]]): [H, W] of every image.
|
| 37 |
+
tensor_outputs (Dict[str, Tensor]): external_output to its tensor.
|
| 38 |
+
|
| 39 |
+
force_mask_on (Bool): if true, it makes sure there will be pred_masks even
|
| 40 |
+
if the mask is not found in tensor_outputs (usually due to a model crash)
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
results = [Instances(image_size) for image_size in image_sizes]
|
| 44 |
+
|
| 45 |
+
batch_splits = tensor_outputs.get("batch_splits", None)
|
| 46 |
+
if batch_splits:
|
| 47 |
+
raise NotImplementedError()
|
| 48 |
+
assert len(image_sizes) == 1
|
| 49 |
+
result = results[0]
|
| 50 |
+
|
| 51 |
+
bbox_nms = tensor_outputs["bbox_nms"]
|
| 52 |
+
score_nms = tensor_outputs["score_nms"]
|
| 53 |
+
class_nms = tensor_outputs["class_nms"]
|
| 54 |
+
# Detection will always succeed because Conv supports 0-batch
|
| 55 |
+
assert bbox_nms is not None
|
| 56 |
+
assert score_nms is not None
|
| 57 |
+
assert class_nms is not None
|
| 58 |
+
if bbox_nms.shape[1] == 5:
|
| 59 |
+
result.pred_boxes = RotatedBoxes(bbox_nms)
|
| 60 |
+
else:
|
| 61 |
+
result.pred_boxes = Boxes(bbox_nms)
|
| 62 |
+
result.scores = score_nms
|
| 63 |
+
result.pred_classes = class_nms.to(torch.int64)
|
| 64 |
+
|
| 65 |
+
mask_fcn_probs = tensor_outputs.get("mask_fcn_probs", None)
|
| 66 |
+
if mask_fcn_probs is not None:
|
| 67 |
+
# finish the mask pred
|
| 68 |
+
mask_probs_pred = mask_fcn_probs
|
| 69 |
+
num_masks = mask_probs_pred.shape[0]
|
| 70 |
+
class_pred = result.pred_classes
|
| 71 |
+
indices = torch.arange(num_masks, device=class_pred.device)
|
| 72 |
+
mask_probs_pred = mask_probs_pred[indices, class_pred][:, None]
|
| 73 |
+
result.pred_masks = mask_probs_pred
|
| 74 |
+
elif force_mask_on:
|
| 75 |
+
# NOTE: there's no way to know the height/width of mask here, it won't be
|
| 76 |
+
# used anyway when batch size is 0, so just set them to 0.
|
| 77 |
+
result.pred_masks = torch.zeros([0, 1, 0, 0], dtype=torch.uint8)
|
| 78 |
+
|
| 79 |
+
keypoints_out = tensor_outputs.get("keypoints_out", None)
|
| 80 |
+
kps_score = tensor_outputs.get("kps_score", None)
|
| 81 |
+
if keypoints_out is not None:
|
| 82 |
+
# keypoints_out: [N, 4, #kypoints], where 4 is in order of (x, y, score, prob)
|
| 83 |
+
keypoints_tensor = keypoints_out
|
| 84 |
+
# NOTE: it's possible that prob is not calculated if "should_output_softmax"
|
| 85 |
+
# is set to False in HeatmapMaxKeypoint, so just using raw score, seems
|
| 86 |
+
# it doesn't affect mAP. TODO: check more carefully.
|
| 87 |
+
keypoint_xyp = keypoints_tensor.transpose(1, 2)[:, :, [0, 1, 2]]
|
| 88 |
+
result.pred_keypoints = keypoint_xyp
|
| 89 |
+
elif kps_score is not None:
|
| 90 |
+
# keypoint heatmap to sparse data structure
|
| 91 |
+
pred_keypoint_logits = kps_score
|
| 92 |
+
keypoint_head.keypoint_rcnn_inference(pred_keypoint_logits, [result])
|
| 93 |
+
|
| 94 |
+
return results
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _cast_to_f32(f64):
|
| 98 |
+
return struct.unpack("f", struct.pack("f", f64))[0]
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def set_caffe2_compatible_tensor_mode(model, enable=True):
|
| 102 |
+
def _fn(m):
|
| 103 |
+
if isinstance(m, Caffe2Compatible):
|
| 104 |
+
m.tensor_mode = enable
|
| 105 |
+
|
| 106 |
+
model.apply(_fn)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def convert_batched_inputs_to_c2_format(batched_inputs, size_divisibility, device):
|
| 110 |
+
"""
|
| 111 |
+
See get_caffe2_inputs() below.
|
| 112 |
+
"""
|
| 113 |
+
assert all(isinstance(x, dict) for x in batched_inputs)
|
| 114 |
+
assert all(x["image"].dim() == 3 for x in batched_inputs)
|
| 115 |
+
|
| 116 |
+
images = [x["image"] for x in batched_inputs]
|
| 117 |
+
images = ImageList.from_tensors(images, size_divisibility)
|
| 118 |
+
|
| 119 |
+
im_info = []
|
| 120 |
+
for input_per_image, image_size in zip(batched_inputs, images.image_sizes):
|
| 121 |
+
target_height = input_per_image.get("height", image_size[0])
|
| 122 |
+
target_width = input_per_image.get("width", image_size[1]) # noqa
|
| 123 |
+
# NOTE: The scale inside im_info is kept as convention and for providing
|
| 124 |
+
# post-processing information if further processing is needed. For
|
| 125 |
+
# current Caffe2 model definitions that don't include post-processing inside
|
| 126 |
+
# the model, this number is not used.
|
| 127 |
+
# NOTE: There can be a slight difference between width and height
|
| 128 |
+
# scales, using a single number can results in numerical difference
|
| 129 |
+
# compared with D2's post-processing.
|
| 130 |
+
scale = target_height / image_size[0]
|
| 131 |
+
im_info.append([image_size[0], image_size[1], scale])
|
| 132 |
+
im_info = torch.Tensor(im_info)
|
| 133 |
+
|
| 134 |
+
return images.tensor.to(device), im_info.to(device)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class Caffe2MetaArch(Caffe2Compatible, torch.nn.Module):
|
| 138 |
+
"""
|
| 139 |
+
Base class for caffe2-compatible implementation of a meta architecture.
|
| 140 |
+
The forward is traceable and its traced graph can be converted to caffe2
|
| 141 |
+
graph through ONNX.
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
def __init__(self, cfg, torch_model):
|
| 145 |
+
"""
|
| 146 |
+
Args:
|
| 147 |
+
cfg (CfgNode):
|
| 148 |
+
torch_model (nn.Module): the detectron2 model (meta_arch) to be
|
| 149 |
+
converted.
|
| 150 |
+
"""
|
| 151 |
+
super().__init__()
|
| 152 |
+
self._wrapped_model = torch_model
|
| 153 |
+
self.eval()
|
| 154 |
+
set_caffe2_compatible_tensor_mode(self, True)
|
| 155 |
+
|
| 156 |
+
def get_caffe2_inputs(self, batched_inputs):
|
| 157 |
+
"""
|
| 158 |
+
Convert pytorch-style structured inputs to caffe2-style inputs that
|
| 159 |
+
are tuples of tensors.
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
batched_inputs (list[dict]): inputs to a detectron2 model
|
| 163 |
+
in its standard format. Each dict has "image" (CHW tensor), and optionally
|
| 164 |
+
"height" and "width".
|
| 165 |
+
|
| 166 |
+
Returns:
|
| 167 |
+
tuple[Tensor]:
|
| 168 |
+
tuple of tensors that will be the inputs to the
|
| 169 |
+
:meth:`forward` method. For existing models, the first
|
| 170 |
+
is an NCHW tensor (padded and batched); the second is
|
| 171 |
+
a im_info Nx3 tensor, where the rows are
|
| 172 |
+
(height, width, unused legacy parameter)
|
| 173 |
+
"""
|
| 174 |
+
return convert_batched_inputs_to_c2_format(
|
| 175 |
+
batched_inputs,
|
| 176 |
+
self._wrapped_model.backbone.size_divisibility,
|
| 177 |
+
self._wrapped_model.device,
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
def encode_additional_info(self, predict_net, init_net):
|
| 181 |
+
"""
|
| 182 |
+
Save extra metadata that will be used by inference in the output protobuf.
|
| 183 |
+
"""
|
| 184 |
+
pass
|
| 185 |
+
|
| 186 |
+
def forward(self, inputs):
|
| 187 |
+
"""
|
| 188 |
+
Run the forward in caffe2-style. It has to use caffe2-compatible ops
|
| 189 |
+
and the method will be used for tracing.
|
| 190 |
+
|
| 191 |
+
Args:
|
| 192 |
+
inputs (tuple[Tensor]): inputs defined by :meth:`get_caffe2_input`.
|
| 193 |
+
They will be the inputs of the converted caffe2 graph.
|
| 194 |
+
|
| 195 |
+
Returns:
|
| 196 |
+
tuple[Tensor]: output tensors. They will be the outputs of the
|
| 197 |
+
converted caffe2 graph.
|
| 198 |
+
"""
|
| 199 |
+
raise NotImplementedError
|
| 200 |
+
|
| 201 |
+
def _caffe2_preprocess_image(self, inputs):
|
| 202 |
+
"""
|
| 203 |
+
Caffe2 implementation of preprocess_image, which is called inside each MetaArch's forward.
|
| 204 |
+
It normalizes the input images, and the final caffe2 graph assumes the
|
| 205 |
+
inputs have been batched already.
|
| 206 |
+
"""
|
| 207 |
+
data, im_info = inputs
|
| 208 |
+
data = alias(data, "data")
|
| 209 |
+
im_info = alias(im_info, "im_info")
|
| 210 |
+
mean, std = self._wrapped_model.pixel_mean, self._wrapped_model.pixel_std
|
| 211 |
+
normalized_data = (data - mean) / std
|
| 212 |
+
normalized_data = alias(normalized_data, "normalized_data")
|
| 213 |
+
|
| 214 |
+
# Pack (data, im_info) into ImageList which is recognized by self.inference.
|
| 215 |
+
images = ImageList(tensor=normalized_data, image_sizes=im_info)
|
| 216 |
+
return images
|
| 217 |
+
|
| 218 |
+
@staticmethod
|
| 219 |
+
def get_outputs_converter(predict_net, init_net):
|
| 220 |
+
"""
|
| 221 |
+
Creates a function that converts outputs of the caffe2 model to
|
| 222 |
+
detectron2's standard format.
|
| 223 |
+
The function uses information in `predict_net` and `init_net` that are
|
| 224 |
+
available at inferene time. Therefore the function logic can be used in inference.
|
| 225 |
+
|
| 226 |
+
The returned function has the following signature:
|
| 227 |
+
|
| 228 |
+
def convert(batched_inputs, c2_inputs, c2_results) -> detectron2_outputs
|
| 229 |
+
|
| 230 |
+
Where
|
| 231 |
+
|
| 232 |
+
* batched_inputs (list[dict]): the original input format of the meta arch
|
| 233 |
+
* c2_inputs (dict[str, Tensor]): the caffe2 inputs.
|
| 234 |
+
* c2_results (dict[str, Tensor]): the caffe2 output format,
|
| 235 |
+
corresponding to the outputs of the :meth:`forward` function.
|
| 236 |
+
* detectron2_outputs: the original output format of the meta arch.
|
| 237 |
+
|
| 238 |
+
This function can be used to compare the outputs of the original meta arch and
|
| 239 |
+
the converted caffe2 graph.
|
| 240 |
+
|
| 241 |
+
Returns:
|
| 242 |
+
callable: a callable of the above signature.
|
| 243 |
+
"""
|
| 244 |
+
raise NotImplementedError
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class Caffe2GeneralizedRCNN(Caffe2MetaArch):
|
| 248 |
+
def __init__(self, cfg, torch_model):
|
| 249 |
+
assert isinstance(torch_model, meta_arch.GeneralizedRCNN)
|
| 250 |
+
torch_model = patch_generalized_rcnn(torch_model)
|
| 251 |
+
super().__init__(cfg, torch_model)
|
| 252 |
+
|
| 253 |
+
self.roi_heads_patcher = ROIHeadsPatcher(cfg, self._wrapped_model.roi_heads)
|
| 254 |
+
|
| 255 |
+
def encode_additional_info(self, predict_net, init_net):
|
| 256 |
+
size_divisibility = self._wrapped_model.backbone.size_divisibility
|
| 257 |
+
check_set_pb_arg(predict_net, "size_divisibility", "i", size_divisibility)
|
| 258 |
+
check_set_pb_arg(
|
| 259 |
+
predict_net, "device", "s", str.encode(str(self._wrapped_model.device), "ascii")
|
| 260 |
+
)
|
| 261 |
+
check_set_pb_arg(predict_net, "meta_architecture", "s", b"GeneralizedRCNN")
|
| 262 |
+
|
| 263 |
+
@mock_torch_nn_functional_interpolate()
|
| 264 |
+
def forward(self, inputs):
|
| 265 |
+
if not self.tensor_mode:
|
| 266 |
+
return self._wrapped_model.inference(inputs)
|
| 267 |
+
images = self._caffe2_preprocess_image(inputs)
|
| 268 |
+
features = self._wrapped_model.backbone(images.tensor)
|
| 269 |
+
proposals, _ = self._wrapped_model.proposal_generator(images, features)
|
| 270 |
+
with self.roi_heads_patcher.mock_roi_heads():
|
| 271 |
+
detector_results, _ = self._wrapped_model.roi_heads(images, features, proposals)
|
| 272 |
+
return tuple(detector_results[0].flatten())
|
| 273 |
+
|
| 274 |
+
@staticmethod
|
| 275 |
+
def get_outputs_converter(predict_net, init_net):
|
| 276 |
+
def f(batched_inputs, c2_inputs, c2_results):
|
| 277 |
+
image_sizes = [[int(im[0]), int(im[1])] for im in c2_inputs["im_info"]]
|
| 278 |
+
results = assemble_rcnn_outputs_by_name(image_sizes, c2_results)
|
| 279 |
+
return meta_arch.GeneralizedRCNN._postprocess(results, batched_inputs, image_sizes)
|
| 280 |
+
|
| 281 |
+
return f
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class Caffe2PanopticFPN(Caffe2MetaArch):
|
| 285 |
+
def __init__(self, cfg, torch_model):
|
| 286 |
+
assert isinstance(torch_model, meta_arch.PanopticFPN)
|
| 287 |
+
torch_model = patch_generalized_rcnn(torch_model)
|
| 288 |
+
super().__init__(cfg, torch_model)
|
| 289 |
+
|
| 290 |
+
self.roi_heads_patcher = ROIHeadsPatcher(cfg, self._wrapped_model.roi_heads)
|
| 291 |
+
|
| 292 |
+
@mock_torch_nn_functional_interpolate()
|
| 293 |
+
def forward(self, inputs):
|
| 294 |
+
assert self.tensor_mode
|
| 295 |
+
images = self._caffe2_preprocess_image(inputs)
|
| 296 |
+
features = self._wrapped_model.backbone(images.tensor)
|
| 297 |
+
|
| 298 |
+
sem_seg_results, _ = self._wrapped_model.sem_seg_head(features)
|
| 299 |
+
sem_seg_results = alias(sem_seg_results, "sem_seg")
|
| 300 |
+
|
| 301 |
+
proposals, _ = self._wrapped_model.proposal_generator(images, features)
|
| 302 |
+
|
| 303 |
+
with self.roi_heads_patcher.mock_roi_heads(self.tensor_mode):
|
| 304 |
+
detector_results, _ = self._wrapped_model.roi_heads(images, features, proposals)
|
| 305 |
+
|
| 306 |
+
return tuple(detector_results[0].flatten()) + (sem_seg_results,)
|
| 307 |
+
|
| 308 |
+
def encode_additional_info(self, predict_net, init_net):
|
| 309 |
+
size_divisibility = self._wrapped_model.backbone.size_divisibility
|
| 310 |
+
check_set_pb_arg(predict_net, "size_divisibility", "i", size_divisibility)
|
| 311 |
+
check_set_pb_arg(
|
| 312 |
+
predict_net, "device", "s", str.encode(str(self._wrapped_model.device), "ascii")
|
| 313 |
+
)
|
| 314 |
+
check_set_pb_arg(predict_net, "meta_architecture", "s", b"PanopticFPN")
|
| 315 |
+
|
| 316 |
+
# Inference parameters:
|
| 317 |
+
check_set_pb_arg(predict_net, "combine_on", "i", self._wrapped_model.combine_on)
|
| 318 |
+
check_set_pb_arg(
|
| 319 |
+
predict_net,
|
| 320 |
+
"combine_overlap_threshold",
|
| 321 |
+
"f",
|
| 322 |
+
_cast_to_f32(self._wrapped_model.combine_overlap_threshold),
|
| 323 |
+
)
|
| 324 |
+
check_set_pb_arg(
|
| 325 |
+
predict_net,
|
| 326 |
+
"combine_stuff_area_limit",
|
| 327 |
+
"i",
|
| 328 |
+
self._wrapped_model.combine_stuff_area_limit,
|
| 329 |
+
)
|
| 330 |
+
check_set_pb_arg(
|
| 331 |
+
predict_net,
|
| 332 |
+
"combine_instances_confidence_threshold",
|
| 333 |
+
"f",
|
| 334 |
+
_cast_to_f32(self._wrapped_model.combine_instances_confidence_threshold),
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
@staticmethod
|
| 338 |
+
def get_outputs_converter(predict_net, init_net):
|
| 339 |
+
combine_on = get_pb_arg_vali(predict_net, "combine_on", None)
|
| 340 |
+
combine_overlap_threshold = get_pb_arg_valf(predict_net, "combine_overlap_threshold", None)
|
| 341 |
+
combine_stuff_area_limit = get_pb_arg_vali(predict_net, "combine_stuff_area_limit", None)
|
| 342 |
+
combine_instances_confidence_threshold = get_pb_arg_valf(
|
| 343 |
+
predict_net, "combine_instances_confidence_threshold", None
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
def f(batched_inputs, c2_inputs, c2_results):
|
| 347 |
+
image_sizes = [[int(im[0]), int(im[1])] for im in c2_inputs["im_info"]]
|
| 348 |
+
detector_results = assemble_rcnn_outputs_by_name(
|
| 349 |
+
image_sizes, c2_results, force_mask_on=True
|
| 350 |
+
)
|
| 351 |
+
sem_seg_results = c2_results["sem_seg"]
|
| 352 |
+
|
| 353 |
+
# copied from meta_arch/panoptic_fpn.py ...
|
| 354 |
+
processed_results = []
|
| 355 |
+
for sem_seg_result, detector_result, input_per_image, image_size in zip(
|
| 356 |
+
sem_seg_results, detector_results, batched_inputs, image_sizes
|
| 357 |
+
):
|
| 358 |
+
height = input_per_image.get("height", image_size[0])
|
| 359 |
+
width = input_per_image.get("width", image_size[1])
|
| 360 |
+
sem_seg_r = sem_seg_postprocess(sem_seg_result, image_size, height, width)
|
| 361 |
+
detector_r = detector_postprocess(detector_result, height, width)
|
| 362 |
+
|
| 363 |
+
processed_results.append({"sem_seg": sem_seg_r, "instances": detector_r})
|
| 364 |
+
|
| 365 |
+
if combine_on:
|
| 366 |
+
panoptic_r = combine_semantic_and_instance_outputs(
|
| 367 |
+
detector_r,
|
| 368 |
+
sem_seg_r.argmax(dim=0),
|
| 369 |
+
combine_overlap_threshold,
|
| 370 |
+
combine_stuff_area_limit,
|
| 371 |
+
combine_instances_confidence_threshold,
|
| 372 |
+
)
|
| 373 |
+
processed_results[-1]["panoptic_seg"] = panoptic_r
|
| 374 |
+
return processed_results
|
| 375 |
+
|
| 376 |
+
return f
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
class Caffe2RetinaNet(Caffe2MetaArch):
|
| 380 |
+
def __init__(self, cfg, torch_model):
|
| 381 |
+
assert isinstance(torch_model, meta_arch.RetinaNet)
|
| 382 |
+
super().__init__(cfg, torch_model)
|
| 383 |
+
|
| 384 |
+
@mock_torch_nn_functional_interpolate()
|
| 385 |
+
def forward(self, inputs):
|
| 386 |
+
assert self.tensor_mode
|
| 387 |
+
images = self._caffe2_preprocess_image(inputs)
|
| 388 |
+
|
| 389 |
+
# explicitly return the images sizes to avoid removing "im_info" by ONNX
|
| 390 |
+
# since it's not used in the forward path
|
| 391 |
+
return_tensors = [images.image_sizes]
|
| 392 |
+
|
| 393 |
+
features = self._wrapped_model.backbone(images.tensor)
|
| 394 |
+
features = [features[f] for f in self._wrapped_model.in_features]
|
| 395 |
+
for i, feature_i in enumerate(features):
|
| 396 |
+
features[i] = alias(feature_i, "feature_{}".format(i), is_backward=True)
|
| 397 |
+
return_tensors.append(features[i])
|
| 398 |
+
|
| 399 |
+
box_cls, box_delta = self._wrapped_model.head(features)
|
| 400 |
+
for i, (box_cls_i, box_delta_i) in enumerate(zip(box_cls, box_delta)):
|
| 401 |
+
return_tensors.append(alias(box_cls_i, "box_cls_{}".format(i)))
|
| 402 |
+
return_tensors.append(alias(box_delta_i, "box_delta_{}".format(i)))
|
| 403 |
+
|
| 404 |
+
return tuple(return_tensors)
|
| 405 |
+
|
| 406 |
+
def encode_additional_info(self, predict_net, init_net):
|
| 407 |
+
size_divisibility = self._wrapped_model.backbone.size_divisibility
|
| 408 |
+
check_set_pb_arg(predict_net, "size_divisibility", "i", size_divisibility)
|
| 409 |
+
check_set_pb_arg(
|
| 410 |
+
predict_net, "device", "s", str.encode(str(self._wrapped_model.device), "ascii")
|
| 411 |
+
)
|
| 412 |
+
check_set_pb_arg(predict_net, "meta_architecture", "s", b"RetinaNet")
|
| 413 |
+
|
| 414 |
+
# Inference parameters:
|
| 415 |
+
check_set_pb_arg(
|
| 416 |
+
predict_net, "score_threshold", "f", _cast_to_f32(self._wrapped_model.score_threshold)
|
| 417 |
+
)
|
| 418 |
+
check_set_pb_arg(predict_net, "topk_candidates", "i", self._wrapped_model.topk_candidates)
|
| 419 |
+
check_set_pb_arg(
|
| 420 |
+
predict_net, "nms_threshold", "f", _cast_to_f32(self._wrapped_model.nms_threshold)
|
| 421 |
+
)
|
| 422 |
+
check_set_pb_arg(
|
| 423 |
+
predict_net,
|
| 424 |
+
"max_detections_per_image",
|
| 425 |
+
"i",
|
| 426 |
+
self._wrapped_model.max_detections_per_image,
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
check_set_pb_arg(
|
| 430 |
+
predict_net,
|
| 431 |
+
"bbox_reg_weights",
|
| 432 |
+
"floats",
|
| 433 |
+
[_cast_to_f32(w) for w in self._wrapped_model.box2box_transform.weights],
|
| 434 |
+
)
|
| 435 |
+
self._encode_anchor_generator_cfg(predict_net)
|
| 436 |
+
|
| 437 |
+
def _encode_anchor_generator_cfg(self, predict_net):
|
| 438 |
+
# serialize anchor_generator for future use
|
| 439 |
+
serialized_anchor_generator = io.BytesIO()
|
| 440 |
+
torch.save(self._wrapped_model.anchor_generator, serialized_anchor_generator)
|
| 441 |
+
# Ideally we can put anchor generating inside the model, then we don't
|
| 442 |
+
# need to store this information.
|
| 443 |
+
bytes = serialized_anchor_generator.getvalue()
|
| 444 |
+
check_set_pb_arg(predict_net, "serialized_anchor_generator", "s", bytes)
|
| 445 |
+
|
| 446 |
+
@staticmethod
|
| 447 |
+
def get_outputs_converter(predict_net, init_net):
|
| 448 |
+
self = types.SimpleNamespace()
|
| 449 |
+
serialized_anchor_generator = io.BytesIO(
|
| 450 |
+
get_pb_arg_vals(predict_net, "serialized_anchor_generator", None)
|
| 451 |
+
)
|
| 452 |
+
self.anchor_generator = torch.load(serialized_anchor_generator)
|
| 453 |
+
bbox_reg_weights = get_pb_arg_floats(predict_net, "bbox_reg_weights", None)
|
| 454 |
+
self.box2box_transform = Box2BoxTransform(weights=tuple(bbox_reg_weights))
|
| 455 |
+
self.score_threshold = get_pb_arg_valf(predict_net, "score_threshold", None)
|
| 456 |
+
self.topk_candidates = get_pb_arg_vali(predict_net, "topk_candidates", None)
|
| 457 |
+
self.nms_threshold = get_pb_arg_valf(predict_net, "nms_threshold", None)
|
| 458 |
+
self.max_detections_per_image = get_pb_arg_vali(
|
| 459 |
+
predict_net, "max_detections_per_image", None
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
# hack to reuse inference code from RetinaNet
|
| 463 |
+
self.inference = functools.partial(meta_arch.RetinaNet.inference, self)
|
| 464 |
+
self.inference_single_image = functools.partial(
|
| 465 |
+
meta_arch.RetinaNet.inference_single_image, self
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
def f(batched_inputs, c2_inputs, c2_results):
|
| 469 |
+
image_sizes = [[int(im[0]), int(im[1])] for im in c2_inputs["im_info"]]
|
| 470 |
+
|
| 471 |
+
num_features = len([x for x in c2_results.keys() if x.startswith("box_cls_")])
|
| 472 |
+
box_cls = [c2_results["box_cls_{}".format(i)] for i in range(num_features)]
|
| 473 |
+
box_delta = [c2_results["box_delta_{}".format(i)] for i in range(num_features)]
|
| 474 |
+
|
| 475 |
+
# For each feature level, feature should have the same batch size and
|
| 476 |
+
# spatial dimension as the box_cls and box_delta.
|
| 477 |
+
dummy_features = [box_delta[i].clone()[:, 0:0, :, :] for i in range(num_features)]
|
| 478 |
+
anchors = self.anchor_generator(dummy_features)
|
| 479 |
+
|
| 480 |
+
# self.num_classess can be inferred
|
| 481 |
+
self.num_classes = box_cls[0].shape[1] // (box_delta[0].shape[1] // 4)
|
| 482 |
+
|
| 483 |
+
results = self.inference(box_cls, box_delta, anchors, image_sizes)
|
| 484 |
+
return meta_arch.GeneralizedRCNN._postprocess(results, batched_inputs, image_sizes)
|
| 485 |
+
|
| 486 |
+
return f
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
META_ARCH_CAFFE2_EXPORT_TYPE_MAP = {
|
| 490 |
+
"GeneralizedRCNN": Caffe2GeneralizedRCNN,
|
| 491 |
+
"PanopticFPN": Caffe2PanopticFPN,
|
| 492 |
+
"RetinaNet": Caffe2RetinaNet,
|
| 493 |
+
}
|
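For orientation, META_ARCH_CAFFE2_EXPORT_TYPE_MAP above is the lookup table from a config's MODEL.META_ARCHITECTURE string to its caffe2-compatible wrapper class. A minimal usage sketch, assuming a built detectron2 `cfg` and a trained `torch_model`; the helper name `wrap_for_caffe2_export` is illustrative and not part of this repository:

# Hypothetical sketch (not part of the diffed file): wrap a trained detectron2
# model with its caffe2-compatible counterpart before ONNX tracing.
from detectron2.export.caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP

def wrap_for_caffe2_export(cfg, torch_model, batched_inputs):
    # Look up the wrapper class by meta-architecture name, e.g. "GeneralizedRCNN".
    C2MetaArch = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[cfg.MODEL.META_ARCHITECTURE]
    c2_compatible_model = C2MetaArch(cfg, torch_model)
    # Convert the list-of-dict inputs into the (data, im_info) tensor tuple
    # that the traceable forward() expects.
    c2_inputs = c2_compatible_model.get_caffe2_inputs(batched_inputs)
    return c2_compatible_model, c2_inputs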
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/export/patcher.py
ADDED
@@ -0,0 +1,153 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import contextlib
import mock
import torch

from detectron2.modeling import poolers
from detectron2.modeling.proposal_generator import rpn
from detectron2.modeling.roi_heads import keypoint_head, mask_head
from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers

from .c10 import (
    Caffe2Compatible,
    Caffe2FastRCNNOutputsInference,
    Caffe2KeypointRCNNInference,
    Caffe2MaskRCNNInference,
    Caffe2ROIPooler,
    Caffe2RPN,
)


class GenericMixin(object):
    pass


class Caffe2CompatibleConverter(object):
    """
    A GenericUpdater which implements the `create_from` interface, by modifying
    the module object and assigning it another class replaceCls.
    """

    def __init__(self, replaceCls):
        self.replaceCls = replaceCls

    def create_from(self, module):
        # update module's class to the new class
        assert isinstance(module, torch.nn.Module)
        if issubclass(self.replaceCls, GenericMixin):
            # replaceCls should act as mixin, create a new class on-the-fly
            new_class = type(
                "{}MixedWith{}".format(self.replaceCls.__name__, module.__class__.__name__),
                (self.replaceCls, module.__class__),
                {},  # {"new_method": lambda self: ...},
            )
            module.__class__ = new_class
        else:
            # replaceCls is a complete class, this allows arbitrary class swap
            module.__class__ = self.replaceCls

        # initialize Caffe2Compatible
        if isinstance(module, Caffe2Compatible):
            module.tensor_mode = False

        return module


def patch(model, target, updater, *args, **kwargs):
    """
    recursively (post-order) update all modules with the target type and its
    subclasses, make an initialization/composition/inheritance/... via the
    updater.create_from.
    """
    for name, module in model.named_children():
        model._modules[name] = patch(module, target, updater, *args, **kwargs)
    if isinstance(model, target):
        return updater.create_from(model, *args, **kwargs)
    return model


def patch_generalized_rcnn(model):
    ccc = Caffe2CompatibleConverter
    model = patch(model, rpn.RPN, ccc(Caffe2RPN))
    model = patch(model, poolers.ROIPooler, ccc(Caffe2ROIPooler))

    return model


@contextlib.contextmanager
def mock_fastrcnn_outputs_inference(
    tensor_mode, check=True, box_predictor_type=FastRCNNOutputLayers
):
    with mock.patch.object(
        box_predictor_type,
        "inference",
        autospec=True,
        side_effect=Caffe2FastRCNNOutputsInference(tensor_mode),
    ) as mocked_func:
        yield
    if check:
        assert mocked_func.call_count > 0


@contextlib.contextmanager
def mock_mask_rcnn_inference(tensor_mode, patched_module, check=True):
    with mock.patch(
        "{}.mask_rcnn_inference".format(patched_module), side_effect=Caffe2MaskRCNNInference()
    ) as mocked_func:
        yield
    if check:
        assert mocked_func.call_count > 0


@contextlib.contextmanager
def mock_keypoint_rcnn_inference(tensor_mode, patched_module, use_heatmap_max_keypoint, check=True):
    with mock.patch(
        "{}.keypoint_rcnn_inference".format(patched_module),
        side_effect=Caffe2KeypointRCNNInference(use_heatmap_max_keypoint),
    ) as mocked_func:
        yield
    if check:
        assert mocked_func.call_count > 0


class ROIHeadsPatcher:
    def __init__(self, cfg, heads):
        self.heads = heads

        self.use_heatmap_max_keypoint = cfg.EXPORT_CAFFE2.USE_HEATMAP_MAX_KEYPOINT

    @contextlib.contextmanager
    def mock_roi_heads(self, tensor_mode=True):
        """
        Patching several inference functions inside ROIHeads and its subclasses

        Args:
            tensor_mode (bool): whether the inputs/outputs are caffe2's tensor
                format or not. Defaults to True.
        """
        # NOTE: this requires the `keypoint_rcnn_inference` and `mask_rcnn_inference`
        # are called inside the same file as BaseXxxHead due to using mock.patch.
        kpt_heads_mod = keypoint_head.BaseKeypointRCNNHead.__module__
        mask_head_mod = mask_head.BaseMaskRCNNHead.__module__

        mock_ctx_managers = [
            mock_fastrcnn_outputs_inference(
                tensor_mode=tensor_mode,
                check=True,
                box_predictor_type=type(self.heads.box_predictor),
            )
        ]
        if getattr(self.heads, "keypoint_on", False):
            mock_ctx_managers += [
                mock_keypoint_rcnn_inference(
                    tensor_mode, kpt_heads_mod, self.use_heatmap_max_keypoint
                )
            ]
        if getattr(self.heads, "mask_on", False):
            mock_ctx_managers += [mock_mask_rcnn_inference(tensor_mode, mask_head_mod)]

        with contextlib.ExitStack() as stack:  # python 3.3+
            for mgr in mock_ctx_managers:
                stack.enter_context(mgr)
            yield
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/export/shared.py
ADDED
@@ -0,0 +1,1034 @@
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
import collections
|
| 4 |
+
import contextlib
|
| 5 |
+
import copy
|
| 6 |
+
import functools
|
| 7 |
+
import logging
|
| 8 |
+
import mock
|
| 9 |
+
import numpy as np
|
| 10 |
+
import os
|
| 11 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 12 |
+
import caffe2.python.utils as putils
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from caffe2.proto import caffe2_pb2
|
| 16 |
+
from caffe2.python import core, net_drawer, workspace
|
| 17 |
+
from torch.nn.functional import interpolate as interp
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# ==== torch/utils_toffee/cast.py =======================================
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def to_device(t, device_str):
|
| 26 |
+
"""
|
| 27 |
+
This function is a replacement of .to(another_device) such that it allows the
|
| 28 |
+
casting to be traced properly by explicitly calling the underlying copy ops.
|
| 29 |
+
It also avoids introducing unncessary op when casting to the same device.
|
| 30 |
+
"""
|
| 31 |
+
src = t.device
|
| 32 |
+
dst = torch.device(device_str)
|
| 33 |
+
|
| 34 |
+
if src == dst:
|
| 35 |
+
return t
|
| 36 |
+
elif src.type == "cuda" and dst.type == "cpu":
|
| 37 |
+
return torch.ops._caffe2.CopyGPUToCPU(t)
|
| 38 |
+
elif src.type == "cpu" and dst.type == "cuda":
|
| 39 |
+
return torch.ops._caffe2.CopyCPUToGPU(t)
|
| 40 |
+
else:
|
| 41 |
+
raise RuntimeError("Can't cast tensor from device {} to device {}".format(src, dst))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ==== torch/utils_toffee/interpolate.py =======================================
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# Note: borrowed from vision/detection/fair/detectron/detectron/modeling/detector.py
|
| 48 |
+
def BilinearInterpolation(tensor_in, up_scale):
|
| 49 |
+
assert up_scale % 2 == 0, "Scale should be even"
|
| 50 |
+
|
| 51 |
+
def upsample_filt(size):
|
| 52 |
+
factor = (size + 1) // 2
|
| 53 |
+
if size % 2 == 1:
|
| 54 |
+
center = factor - 1
|
| 55 |
+
else:
|
| 56 |
+
center = factor - 0.5
|
| 57 |
+
|
| 58 |
+
og = np.ogrid[:size, :size]
|
| 59 |
+
return (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
|
| 60 |
+
|
| 61 |
+
kernel_size = int(up_scale) * 2
|
| 62 |
+
bil_filt = upsample_filt(kernel_size)
|
| 63 |
+
|
| 64 |
+
dim = int(tensor_in.shape[1])
|
| 65 |
+
kernel = np.zeros((dim, dim, kernel_size, kernel_size), dtype=np.float32)
|
| 66 |
+
kernel[range(dim), range(dim), :, :] = bil_filt
|
| 67 |
+
|
| 68 |
+
tensor_out = F.conv_transpose2d(
|
| 69 |
+
tensor_in,
|
| 70 |
+
weight=to_device(torch.Tensor(kernel), tensor_in.device),
|
| 71 |
+
bias=None,
|
| 72 |
+
stride=int(up_scale),
|
| 73 |
+
padding=int(up_scale / 2),
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
return tensor_out
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# NOTE: ONNX is incompatible with traced torch.nn.functional.interpolate if
|
| 80 |
+
# using dynamic `scale_factor` rather than static `size`. (T43166860)
|
| 81 |
+
# NOTE: Caffe2 Int8 conversion might not be able to quantize `size` properly.
|
| 82 |
+
def onnx_compatibale_interpolate(
|
| 83 |
+
input, size=None, scale_factor=None, mode="nearest", align_corners=None
|
| 84 |
+
):
|
| 85 |
+
# NOTE: The input dimensions are interpreted in the form:
|
| 86 |
+
# `mini-batch x channels x [optional depth] x [optional height] x width`.
|
| 87 |
+
if size is None and scale_factor is not None:
|
| 88 |
+
if input.dim() == 4:
|
| 89 |
+
if isinstance(scale_factor, (int, float)):
|
| 90 |
+
height_scale, width_scale = (scale_factor, scale_factor)
|
| 91 |
+
else:
|
| 92 |
+
assert isinstance(scale_factor, (tuple, list))
|
| 93 |
+
assert len(scale_factor) == 2
|
| 94 |
+
height_scale, width_scale = scale_factor
|
| 95 |
+
|
| 96 |
+
assert not align_corners, "No matching C2 op for align_corners == True"
|
| 97 |
+
if mode == "nearest":
|
| 98 |
+
return torch.ops._caffe2.ResizeNearest(
|
| 99 |
+
input, order="NCHW", width_scale=width_scale, height_scale=height_scale
|
| 100 |
+
)
|
| 101 |
+
elif mode == "bilinear":
|
| 102 |
+
logger.warning(
|
| 103 |
+
"Use F.conv_transpose2d for bilinear interpolate"
|
| 104 |
+
" because there's no such C2 op, this may cause significant"
|
| 105 |
+
" slowdown and the boundary pixels won't be as same as"
|
| 106 |
+
" using F.interpolate due to padding."
|
| 107 |
+
)
|
| 108 |
+
assert height_scale == width_scale
|
| 109 |
+
return BilinearInterpolation(input, up_scale=height_scale)
|
| 110 |
+
logger.warning("Output size is not static, it might cause ONNX conversion issue")
|
| 111 |
+
|
| 112 |
+
return interp(input, size, scale_factor, mode, align_corners)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
@contextlib.contextmanager
|
| 116 |
+
def mock_torch_nn_functional_interpolate():
|
| 117 |
+
if torch.onnx.is_in_onnx_export():
|
| 118 |
+
with mock.patch(
|
| 119 |
+
"torch.nn.functional.interpolate", side_effect=onnx_compatibale_interpolate
|
| 120 |
+
):
|
| 121 |
+
yield
|
| 122 |
+
else:
|
| 123 |
+
yield
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# ==== torch/utils_caffe2/ws_utils.py ==========================================
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class ScopedWS(object):
|
| 130 |
+
def __init__(self, ws_name, is_reset, is_cleanup=False):
|
| 131 |
+
self.ws_name = ws_name
|
| 132 |
+
self.is_reset = is_reset
|
| 133 |
+
self.is_cleanup = is_cleanup
|
| 134 |
+
self.org_ws = ""
|
| 135 |
+
|
| 136 |
+
def __enter__(self):
|
| 137 |
+
self.org_ws = workspace.CurrentWorkspace()
|
| 138 |
+
if self.ws_name is not None:
|
| 139 |
+
workspace.SwitchWorkspace(self.ws_name, True)
|
| 140 |
+
if self.is_reset:
|
| 141 |
+
workspace.ResetWorkspace()
|
| 142 |
+
|
| 143 |
+
return workspace
|
| 144 |
+
|
| 145 |
+
def __exit__(self, *args):
|
| 146 |
+
if self.is_cleanup:
|
| 147 |
+
workspace.ResetWorkspace()
|
| 148 |
+
if self.ws_name is not None:
|
| 149 |
+
workspace.SwitchWorkspace(self.org_ws)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def fetch_any_blob(name):
|
| 153 |
+
bb = None
|
| 154 |
+
try:
|
| 155 |
+
bb = workspace.FetchBlob(name)
|
| 156 |
+
except TypeError:
|
| 157 |
+
bb = workspace.FetchInt8Blob(name)
|
| 158 |
+
except Exception as e:
|
| 159 |
+
logger.error("Get blob {} error: {}".format(name, e))
|
| 160 |
+
|
| 161 |
+
return bb
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
# ==== torch/utils_caffe2/protobuf.py ==========================================
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def get_pb_arg(pb, arg_name):
|
| 168 |
+
for x in pb.arg:
|
| 169 |
+
if x.name == arg_name:
|
| 170 |
+
return x
|
| 171 |
+
return None
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def get_pb_arg_valf(pb, arg_name, default_val):
|
| 175 |
+
arg = get_pb_arg(pb, arg_name)
|
| 176 |
+
return arg.f if arg is not None else default_val
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def get_pb_arg_floats(pb, arg_name, default_val):
|
| 180 |
+
arg = get_pb_arg(pb, arg_name)
|
| 181 |
+
return list(map(float, arg.floats)) if arg is not None else default_val
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def get_pb_arg_ints(pb, arg_name, default_val):
|
| 185 |
+
arg = get_pb_arg(pb, arg_name)
|
| 186 |
+
return list(map(int, arg.ints)) if arg is not None else default_val
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def get_pb_arg_vali(pb, arg_name, default_val):
|
| 190 |
+
arg = get_pb_arg(pb, arg_name)
|
| 191 |
+
return arg.i if arg is not None else default_val
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def get_pb_arg_vals(pb, arg_name, default_val):
|
| 195 |
+
arg = get_pb_arg(pb, arg_name)
|
| 196 |
+
return arg.s if arg is not None else default_val
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def get_pb_arg_valstrings(pb, arg_name, default_val):
|
| 200 |
+
arg = get_pb_arg(pb, arg_name)
|
| 201 |
+
return list(arg.strings) if arg is not None else default_val
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def check_set_pb_arg(pb, arg_name, arg_attr, arg_value, allow_override=False):
|
| 205 |
+
arg = get_pb_arg(pb, arg_name)
|
| 206 |
+
if arg is None:
|
| 207 |
+
arg = putils.MakeArgument(arg_name, arg_value)
|
| 208 |
+
assert hasattr(arg, arg_attr)
|
| 209 |
+
pb.arg.extend([arg])
|
| 210 |
+
if allow_override and getattr(arg, arg_attr) != arg_value:
|
| 211 |
+
logger.warning(
|
| 212 |
+
"Override argument {}: {} -> {}".format(arg_name, getattr(arg, arg_attr), arg_value)
|
| 213 |
+
)
|
| 214 |
+
setattr(arg, arg_attr, arg_value)
|
| 215 |
+
else:
|
| 216 |
+
assert arg is not None
|
| 217 |
+
assert getattr(arg, arg_attr) == arg_value, "Existing value {}, new value {}".format(
|
| 218 |
+
getattr(arg, arg_attr), arg_value
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def _create_const_fill_op_from_numpy(name, tensor, device_option=None):
|
| 223 |
+
assert type(tensor) == np.ndarray
|
| 224 |
+
kTypeNameMapper = {
|
| 225 |
+
np.dtype("float32"): "GivenTensorFill",
|
| 226 |
+
np.dtype("int32"): "GivenTensorIntFill",
|
| 227 |
+
np.dtype("int64"): "GivenTensorInt64Fill",
|
| 228 |
+
np.dtype("uint8"): "GivenTensorStringFill",
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
args_dict = {}
|
| 232 |
+
if tensor.dtype == np.dtype("uint8"):
|
| 233 |
+
args_dict.update({"values": [str(tensor.data)], "shape": [1]})
|
| 234 |
+
else:
|
| 235 |
+
args_dict.update({"values": tensor, "shape": tensor.shape})
|
| 236 |
+
|
| 237 |
+
if device_option is not None:
|
| 238 |
+
args_dict["device_option"] = device_option
|
| 239 |
+
|
| 240 |
+
return core.CreateOperator(kTypeNameMapper[tensor.dtype], [], [name], **args_dict)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def _create_const_fill_op_from_c2_int8_tensor(name, int8_tensor):
|
| 244 |
+
assert type(int8_tensor) == workspace.Int8Tensor
|
| 245 |
+
kTypeNameMapper = {
|
| 246 |
+
np.dtype("int32"): "Int8GivenIntTensorFill",
|
| 247 |
+
np.dtype("uint8"): "Int8GivenTensorFill",
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
tensor = int8_tensor.data
|
| 251 |
+
assert tensor.dtype in [np.dtype("uint8"), np.dtype("int32")]
|
| 252 |
+
values = tensor.tobytes() if tensor.dtype == np.dtype("uint8") else tensor
|
| 253 |
+
|
| 254 |
+
return core.CreateOperator(
|
| 255 |
+
kTypeNameMapper[tensor.dtype],
|
| 256 |
+
[],
|
| 257 |
+
[name],
|
| 258 |
+
values=values,
|
| 259 |
+
shape=tensor.shape,
|
| 260 |
+
Y_scale=int8_tensor.scale,
|
| 261 |
+
Y_zero_point=int8_tensor.zero_point,
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def create_const_fill_op(
|
| 266 |
+
name: str,
|
| 267 |
+
blob: Union[np.ndarray, workspace.Int8Tensor],
|
| 268 |
+
device_option: Optional[caffe2_pb2.DeviceOption] = None,
|
| 269 |
+
) -> caffe2_pb2.OperatorDef:
|
| 270 |
+
"""
|
| 271 |
+
Given a blob object, return the Caffe2 operator that creates this blob
|
| 272 |
+
as constant. Currently support NumPy tensor and Caffe2 Int8Tensor.
|
| 273 |
+
"""
|
| 274 |
+
|
| 275 |
+
tensor_type = type(blob)
|
| 276 |
+
assert tensor_type in [
|
| 277 |
+
np.ndarray,
|
| 278 |
+
workspace.Int8Tensor,
|
| 279 |
+
], 'Error when creating const fill op for "{}", unsupported blob type: {}'.format(
|
| 280 |
+
name, type(blob)
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
if tensor_type == np.ndarray:
|
| 284 |
+
return _create_const_fill_op_from_numpy(name, blob, device_option)
|
| 285 |
+
elif tensor_type == workspace.Int8Tensor:
|
| 286 |
+
assert device_option is None
|
| 287 |
+
return _create_const_fill_op_from_c2_int8_tensor(name, blob)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def construct_init_net_from_params(
|
| 291 |
+
params: Dict[str, Any], device_options: Optional[Dict[str, caffe2_pb2.DeviceOption]] = None
|
| 292 |
+
) -> caffe2_pb2.NetDef:
|
| 293 |
+
"""
|
| 294 |
+
Construct the init_net from params dictionary
|
| 295 |
+
"""
|
| 296 |
+
init_net = caffe2_pb2.NetDef()
|
| 297 |
+
device_options = device_options or {}
|
| 298 |
+
for name, blob in params.items():
|
| 299 |
+
if isinstance(blob, str):
|
| 300 |
+
logger.warning(
|
| 301 |
+
(
|
| 302 |
+
"Blob {} with type {} is not supported in generating init net,"
|
| 303 |
+
" skipped.".format(name, type(blob))
|
| 304 |
+
)
|
| 305 |
+
)
|
| 306 |
+
continue
|
| 307 |
+
init_net.op.extend(
|
| 308 |
+
[create_const_fill_op(name, blob, device_option=device_options.get(name, None))]
|
| 309 |
+
)
|
| 310 |
+
init_net.external_output.append(name)
|
| 311 |
+
return init_net
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def get_producer_map(ssa):
|
| 315 |
+
"""
|
| 316 |
+
Return dict from versioned blob to (i, j),
|
| 317 |
+
where i is index of producer op, j is the index of output of that op.
|
| 318 |
+
"""
|
| 319 |
+
producer_map = {}
|
| 320 |
+
for i in range(len(ssa)):
|
| 321 |
+
outputs = ssa[i][1]
|
| 322 |
+
for j, outp in enumerate(outputs):
|
| 323 |
+
producer_map[outp] = (i, j)
|
| 324 |
+
return producer_map
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def get_consumer_map(ssa):
|
| 328 |
+
"""
|
| 329 |
+
Return dict from versioned blob to list of (i, j),
|
| 330 |
+
where i is index of consumer op, j is the index of input of that op.
|
| 331 |
+
"""
|
| 332 |
+
consumer_map = collections.defaultdict(list)
|
| 333 |
+
for i in range(len(ssa)):
|
| 334 |
+
inputs = ssa[i][0]
|
| 335 |
+
for j, inp in enumerate(inputs):
|
| 336 |
+
consumer_map[inp].append((i, j))
|
| 337 |
+
return consumer_map
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def get_params_from_init_net(
|
| 341 |
+
init_net: caffe2_pb2.NetDef,
|
| 342 |
+
) -> [Dict[str, Any], Dict[str, caffe2_pb2.DeviceOption]]:
|
| 343 |
+
"""
|
| 344 |
+
Take the output blobs from init_net by running it.
|
| 345 |
+
Outputs:
|
| 346 |
+
params: dict from blob name to numpy array
|
| 347 |
+
device_options: dict from blob name to the device option of its creating op
|
| 348 |
+
"""
|
| 349 |
+
# NOTE: this assumes that the params is determined by producer op with the
|
| 350 |
+
# only exception be CopyGPUToCPU which is CUDA op but returns CPU tensor.
|
| 351 |
+
def _get_device_option(producer_op):
|
| 352 |
+
if producer_op.type == "CopyGPUToCPU":
|
| 353 |
+
return caffe2_pb2.DeviceOption()
|
| 354 |
+
else:
|
| 355 |
+
return producer_op.device_option
|
| 356 |
+
|
| 357 |
+
with ScopedWS("__get_params_from_init_net__", is_reset=True, is_cleanup=True) as ws:
|
| 358 |
+
ws.RunNetOnce(init_net)
|
| 359 |
+
params = {b: fetch_any_blob(b) for b in init_net.external_output}
|
| 360 |
+
ssa, versions = core.get_ssa(init_net)
|
| 361 |
+
producer_map = get_producer_map(ssa)
|
| 362 |
+
device_options = {
|
| 363 |
+
b: _get_device_option(init_net.op[producer_map[(b, versions[b])][0]])
|
| 364 |
+
for b in init_net.external_output
|
| 365 |
+
}
|
| 366 |
+
return params, device_options
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def _updater_raise(op, input_types, output_types):
|
| 370 |
+
raise RuntimeError(
|
| 371 |
+
"Failed to apply updater for op {} given input_types {} and"
|
| 372 |
+
" output_types {}".format(op, input_types, output_types)
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def _generic_status_identifier(
|
| 377 |
+
predict_net: caffe2_pb2.NetDef,
|
| 378 |
+
status_updater: Callable,
|
| 379 |
+
known_status: Dict[Tuple[str, int], Any],
|
| 380 |
+
) -> Dict[Tuple[str, int], Any]:
|
| 381 |
+
"""
|
| 382 |
+
Statically infer the status of each blob, the status can be such as device type
|
| 383 |
+
(CPU/GPU), layout (NCHW/NHWC), data type (float32/int8), etc. "Blob" here
|
| 384 |
+
is versioned blob (Tuple[str, int]) in the format compatible with ssa.
|
| 385 |
+
Inputs:
|
| 386 |
+
predict_net: the caffe2 network
|
| 387 |
+
status_updater: a callable, given an op and the status of its input/output,
|
| 388 |
+
it returns the updated status of input/output. `None` is used for
|
| 389 |
+
representing unknown status.
|
| 390 |
+
known_status: a dict containing known status, used as initialization.
|
| 391 |
+
Outputs:
|
| 392 |
+
A dict mapping from versioned blob to its status
|
| 393 |
+
"""
|
| 394 |
+
ssa, versions = core.get_ssa(predict_net)
|
| 395 |
+
versioned_ext_input = [(b, 0) for b in predict_net.external_input]
|
| 396 |
+
versioned_ext_output = [(b, versions[b]) for b in predict_net.external_output]
|
| 397 |
+
all_versioned_blobs = set().union(*[set(x[0] + x[1]) for x in ssa])
|
| 398 |
+
|
| 399 |
+
allowed_vbs = all_versioned_blobs.union(versioned_ext_input).union(versioned_ext_output)
|
| 400 |
+
assert all(k in allowed_vbs for k in known_status)
|
| 401 |
+
assert all(v is not None for v in known_status.values())
|
| 402 |
+
_known_status = copy.deepcopy(known_status)
|
| 403 |
+
|
| 404 |
+
def _check_and_update(key, value):
|
| 405 |
+
assert value is not None
|
| 406 |
+
if key in _known_status:
|
| 407 |
+
if not _known_status[key] == value:
|
| 408 |
+
raise RuntimeError(
|
| 409 |
+
"Confilict status for {}, existing status {}, new status {}".format(
|
| 410 |
+
key, _known_status[key], value
|
| 411 |
+
)
|
| 412 |
+
)
|
| 413 |
+
_known_status[key] = value
|
| 414 |
+
|
| 415 |
+
def _update_i(op, ssa_i):
|
| 416 |
+
versioned_inputs = ssa_i[0]
|
| 417 |
+
versioned_outputs = ssa_i[1]
|
| 418 |
+
|
| 419 |
+
inputs_status = [_known_status.get(b, None) for b in versioned_inputs]
|
| 420 |
+
outputs_status = [_known_status.get(b, None) for b in versioned_outputs]
|
| 421 |
+
|
| 422 |
+
new_inputs_status, new_outputs_status = status_updater(op, inputs_status, outputs_status)
|
| 423 |
+
|
| 424 |
+
for versioned_blob, status in zip(
|
| 425 |
+
versioned_inputs + versioned_outputs, new_inputs_status + new_outputs_status
|
| 426 |
+
):
|
| 427 |
+
if status is not None:
|
| 428 |
+
_check_and_update(versioned_blob, status)
|
| 429 |
+
|
| 430 |
+
for op, ssa_i in zip(predict_net.op, ssa):
|
| 431 |
+
_update_i(op, ssa_i)
|
| 432 |
+
for op, ssa_i in zip(reversed(predict_net.op), reversed(ssa)):
|
| 433 |
+
_update_i(op, ssa_i)
|
| 434 |
+
|
| 435 |
+
# NOTE: This strictly checks all the blob from predict_net must be assgined
|
| 436 |
+
# a known status. However sometimes it's impossible (eg. having deadend op),
|
| 437 |
+
# we may relax this constraint if
|
| 438 |
+
for k in all_versioned_blobs:
|
| 439 |
+
if k not in _known_status:
|
| 440 |
+
raise NotImplementedError(
|
| 441 |
+
"Can not infer the status for {}. Currently only support the case where"
|
| 442 |
+
" a single forward and backward pass can identify status for all blobs.".format(k)
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
return _known_status
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def infer_device_type(
|
| 449 |
+
predict_net: caffe2_pb2.NetDef,
|
| 450 |
+
known_status: Dict[Tuple[str, int], Any],
|
| 451 |
+
device_name_style: str = "caffe2",
|
| 452 |
+
) -> Dict[Tuple[str, int], str]:
|
| 453 |
+
""" Return the device type ("cpu" or "gpu"/"cuda") of each (versioned) blob """
|
| 454 |
+
|
| 455 |
+
assert device_name_style in ["caffe2", "pytorch"]
|
| 456 |
+
_CPU_STR = "cpu"
|
| 457 |
+
_GPU_STR = "gpu" if device_name_style == "caffe2" else "cuda"
|
| 458 |
+
|
| 459 |
+
def _copy_cpu_to_gpu_updater(op, input_types, output_types):
|
| 460 |
+
if input_types[0] == _GPU_STR or output_types[0] == _CPU_STR:
|
| 461 |
+
_updater_raise(op, input_types, output_types)
|
| 462 |
+
return ([_CPU_STR], [_GPU_STR])
|
| 463 |
+
|
| 464 |
+
def _copy_gpu_to_cpu_updater(op, input_types, output_types):
|
| 465 |
+
if input_types[0] == _CPU_STR or output_types[0] == _GPU_STR:
|
| 466 |
+
_updater_raise(op, input_types, output_types)
|
| 467 |
+
return ([_GPU_STR], [_CPU_STR])
|
| 468 |
+
|
| 469 |
+
def _other_ops_updater(op, input_types, output_types):
|
| 470 |
+
non_none_types = [x for x in input_types + output_types if x is not None]
|
| 471 |
+
if len(non_none_types) > 0:
|
| 472 |
+
the_type = non_none_types[0]
|
| 473 |
+
if not all(x == the_type for x in non_none_types):
|
| 474 |
+
_updater_raise(op, input_types, output_types)
|
| 475 |
+
else:
|
| 476 |
+
the_type = None
|
| 477 |
+
return ([the_type for _ in op.input], [the_type for _ in op.output])
|
| 478 |
+
|
| 479 |
+
def _device_updater(op, *args, **kwargs):
|
| 480 |
+
return {
|
| 481 |
+
"CopyCPUToGPU": _copy_cpu_to_gpu_updater,
|
| 482 |
+
"CopyGPUToCPU": _copy_gpu_to_cpu_updater,
|
| 483 |
+
}.get(op.type, _other_ops_updater)(op, *args, **kwargs)
|
| 484 |
+
|
| 485 |
+
return _generic_status_identifier(predict_net, _device_updater, known_status)
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
# ==== torch/utils_caffe2/vis.py ===============================================
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
def _modify_blob_names(ops, blob_rename_f):
|
| 492 |
+
ret = []
|
| 493 |
+
|
| 494 |
+
def _replace_list(blob_list, replaced_list):
|
| 495 |
+
del blob_list[:]
|
| 496 |
+
blob_list.extend(replaced_list)
|
| 497 |
+
|
| 498 |
+
for x in ops:
|
| 499 |
+
cur = copy.deepcopy(x)
|
| 500 |
+
_replace_list(cur.input, list(map(blob_rename_f, cur.input)))
|
| 501 |
+
_replace_list(cur.output, list(map(blob_rename_f, cur.output)))
|
| 502 |
+
ret.append(cur)
|
| 503 |
+
|
| 504 |
+
return ret
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
def _rename_blob(name, blob_sizes, blob_ranges):
|
| 508 |
+
def _list_to_str(bsize):
|
| 509 |
+
ret = ", ".join([str(x) for x in bsize])
|
| 510 |
+
ret = "[" + ret + "]"
|
| 511 |
+
return ret
|
| 512 |
+
|
| 513 |
+
ret = name
|
| 514 |
+
if blob_sizes is not None and name in blob_sizes:
|
| 515 |
+
ret += "\n" + _list_to_str(blob_sizes[name])
|
| 516 |
+
if blob_ranges is not None and name in blob_ranges:
|
| 517 |
+
ret += "\n" + _list_to_str(blob_ranges[name])
|
| 518 |
+
|
| 519 |
+
return ret
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
# graph_name could not contain word 'graph'
|
| 523 |
+
def save_graph(net, file_name, graph_name="net", op_only=True, blob_sizes=None, blob_ranges=None):
|
| 524 |
+
blob_rename_f = functools.partial(_rename_blob, blob_sizes=blob_sizes, blob_ranges=blob_ranges)
|
| 525 |
+
return save_graph_base(net, file_name, graph_name, op_only, blob_rename_f)
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def save_graph_base(net, file_name, graph_name="net", op_only=True, blob_rename_func=None):
|
| 529 |
+
graph = None
|
| 530 |
+
ops = net.op
|
| 531 |
+
if blob_rename_func is not None:
|
| 532 |
+
ops = _modify_blob_names(ops, blob_rename_func)
|
| 533 |
+
if not op_only:
|
| 534 |
+
graph = net_drawer.GetPydotGraph(ops, graph_name, rankdir="TB")
|
| 535 |
+
else:
|
| 536 |
+
graph = net_drawer.GetPydotGraphMinimal(
|
| 537 |
+
ops, graph_name, rankdir="TB", minimal_dependency=True
|
| 538 |
+
)
|
| 539 |
+
|
| 540 |
+
try:
|
| 541 |
+
par_dir = os.path.dirname(file_name)
|
| 542 |
+
if not os.path.exists(par_dir):
|
| 543 |
+
os.makedirs(par_dir)
|
| 544 |
+
|
| 545 |
+
format = os.path.splitext(os.path.basename(file_name))[-1]
|
| 546 |
+
if format == ".png":
|
| 547 |
+
graph.write_png(file_name)
|
| 548 |
+
elif format == ".pdf":
|
| 549 |
+
graph.write_pdf(file_name)
|
| 550 |
+
elif format == ".svg":
|
| 551 |
+
graph.write_svg(file_name)
|
| 552 |
+
else:
|
| 553 |
+
print("Incorrect format {}".format(format))
|
| 554 |
+
except Exception as e:
|
| 555 |
+
print("Error when writing graph to image {}".format(e))
|
| 556 |
+
|
| 557 |
+
return graph


# ==== torch/utils_toffee/aten_to_caffe2.py ====================================


def group_norm_replace_aten_with_caffe2(predict_net: caffe2_pb2.NetDef):
    """
    For an ONNX-exported model, GroupNorm is represented as an ATen op;
    this performs a drop-in replacement from ATen to the Caffe2 GroupNorm op.
    """
    count = 0
    for op in predict_net.op:
        if op.type == "ATen":
            op_name = get_pb_arg_vals(op, "operator", None)  # returns bytes in py3
            if op_name and op_name.decode() == "group_norm":
                op.arg.remove(get_pb_arg(op, "operator"))

                if get_pb_arg_vali(op, "cudnn_enabled", None):
                    op.arg.remove(get_pb_arg(op, "cudnn_enabled"))

                num_groups = get_pb_arg_vali(op, "num_groups", None)
                if num_groups is not None:
                    op.arg.remove(get_pb_arg(op, "num_groups"))
                    check_set_pb_arg(op, "group", "i", num_groups)

                op.type = "GroupNorm"
                count += 1
    if count > 1:
        logger.info("Replaced {} ATen operators with GroupNorm".format(count))
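

# Illustrative sketch (not part of the upstream file): the ATen -> GroupNorm
# rewrite on a minimal hand-built NetDef. Blob names are hypothetical.
def _example_group_norm_replace():
    from caffe2.python import core

    predict_net = caffe2_pb2.NetDef()
    aten_op = core.CreateOperator(
        "ATen", ["x", "gamma", "beta"], ["y"], operator="group_norm", num_groups=32
    )
    predict_net.op.extend([aten_op])
    group_norm_replace_aten_with_caffe2(predict_net)
    # The op type is now "GroupNorm" and num_groups has become the "group" arg.
    assert predict_net.op[0].type == "GroupNorm"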


# ==== torch/utils_toffee/alias.py =============================================


def alias(x, name, is_backward=False):
    if not torch.onnx.is_in_onnx_export():
        return x
    assert isinstance(x, torch.Tensor)
    return torch.ops._caffe2.AliasWithName(x, name, is_backward=is_backward)


def fuse_alias_placeholder(predict_net, init_net):
    """ Remove AliasWithName placeholders and rename their inputs/outputs. """
    # First, finish all the renaming
    for i, op in enumerate(predict_net.op):
        if op.type == "AliasWithName":
            assert len(op.input) == 1
            assert len(op.output) == 1
            name = get_pb_arg_vals(op, "name", None).decode()
            is_backward = bool(get_pb_arg_vali(op, "is_backward", 0))
            rename_op_input(predict_net, init_net, i, 0, name, from_producer=is_backward)
            rename_op_output(predict_net, i, 0, name)

    # Remove the AliasWithName ops; this is safe since AliasWithName is a no-op
    new_ops = []
    for op in predict_net.op:
        if op.type != "AliasWithName":
            new_ops.append(op)
        else:
            # safety check
            assert op.input == op.output
            assert op.input[0] == op.arg[0].s.decode()
    del predict_net.op[:]
    predict_net.op.extend(new_ops)
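

# Illustrative sketch (not part of the upstream file): how alias() is typically
# called inside a model's forward() while tracing for export. Outside of ONNX
# export mode it is a pass-through; during export it emits a named
# AliasWithName op that fuse_alias_placeholder() later folds away.
def _example_alias_usage(features):
    features = alias(features, "backbone_features")  # hypothetical blob name
    return features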


# ==== torch/utils_caffe2/graph_transform.py ===================================


class IllegalGraphTransformError(ValueError):
    """ Raised when a graph transform function call can't be executed. """


def _rename_versioned_blob_in_proto(
    proto: caffe2_pb2.NetDef,
    old_name: str,
    new_name: str,
    version: int,
    ssa: List[Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]],
    start_versions: Dict[str, int],
    end_versions: Dict[str, int],
):
    """ In the given proto, rename all blobs whose (name, version) matches. """
    # Operator list
    for op, i_th_ssa in zip(proto.op, ssa):
        versioned_inputs, versioned_outputs = i_th_ssa
        for i in range(len(op.input)):
            if versioned_inputs[i] == (old_name, version):
                op.input[i] = new_name
        for i in range(len(op.output)):
            if versioned_outputs[i] == (old_name, version):
                op.output[i] = new_name
    # external_input
    if start_versions.get(old_name, 0) == version:
        for i in range(len(proto.external_input)):
            if proto.external_input[i] == old_name:
                proto.external_input[i] = new_name
    # external_output
    if end_versions.get(old_name, 0) == version:
        for i in range(len(proto.external_output)):
            if proto.external_output[i] == old_name:
                proto.external_output[i] = new_name


def rename_op_input(
    predict_net: caffe2_pb2.NetDef,
    init_net: caffe2_pb2.NetDef,
    op_id: int,
    input_id: int,
    new_name: str,
    from_producer: bool = False,
):
    """
    Rename the op_id-th operator in predict_net, changing its input_id-th input's
    name to new_name. It also re-routes automatically and updates
    external_input and init_net if necessary.
    - It requires that the input is only consumed by this op.
    - This function modifies predict_net and init_net in-place.
    - When from_producer is enabled, this also updates other operators that consume
      the same input. Be cautious, because this may trigger unintended behavior.
    """
    assert isinstance(predict_net, caffe2_pb2.NetDef)
    assert isinstance(init_net, caffe2_pb2.NetDef)

    init_net_ssa, init_net_versions = core.get_ssa(init_net)
    predict_net_ssa, predict_net_versions = core.get_ssa(
        predict_net, copy.deepcopy(init_net_versions)
    )

    versioned_inputs, versioned_outputs = predict_net_ssa[op_id]
    old_name, version = versioned_inputs[input_id]

    if from_producer:
        producer_map = get_producer_map(predict_net_ssa)
        if (old_name, version) not in producer_map:
            raise NotImplementedError(
                "Can't find producer, the input {} is probably from"
                " init_net, this is not supported yet.".format(old_name)
            )
        producer = producer_map[(old_name, version)]
        rename_op_output(predict_net, producer[0], producer[1], new_name)
        return

    def contain_targets(op_ssa):
        return (old_name, version) in op_ssa[0]

    is_consumer = [contain_targets(op_ssa) for op_ssa in predict_net_ssa]
    if sum(is_consumer) > 1:
        raise IllegalGraphTransformError(
            (
                "Input '{}' of operator(#{}) is consumed by other ops, please use"
                + " rename_op_output on the producer instead. Offending op: \n{}"
            ).format(old_name, op_id, predict_net.op[op_id])
        )

    # update init_net
    _rename_versioned_blob_in_proto(
        init_net, old_name, new_name, version, init_net_ssa, {}, init_net_versions
    )
    # update predict_net
    _rename_versioned_blob_in_proto(
        predict_net,
        old_name,
        new_name,
        version,
        predict_net_ssa,
        init_net_versions,
        predict_net_versions,
    )


def rename_op_output(predict_net: caffe2_pb2.NetDef, op_id: int, output_id: int, new_name: str):
    """
    Rename the op_id-th operator in predict_net, changing its output_id-th output's
    name to new_name. It also re-routes automatically and updates
    external_output if necessary.
    - It allows multiple consumers of its output.
    - This function modifies predict_net in-place and doesn't need init_net.
    """
    assert isinstance(predict_net, caffe2_pb2.NetDef)

    ssa, blob_versions = core.get_ssa(predict_net)

    versioned_inputs, versioned_outputs = ssa[op_id]
    old_name, version = versioned_outputs[output_id]

    # update predict_net
    _rename_versioned_blob_in_proto(
        predict_net, old_name, new_name, version, ssa, {}, blob_versions
    )
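

# Illustrative sketch (not part of the upstream file): renaming an op output on
# a two-op toy net. Blob names are hypothetical.
def _example_rename_op_output():
    from caffe2.python import core

    net = caffe2_pb2.NetDef()
    net.op.extend(
        [
            core.CreateOperator("Relu", ["x"], ["y"]),
            core.CreateOperator("Softmax", ["y"], ["prob"]),
        ]
    )
    net.external_input.append("x")
    net.external_output.append("prob")
    # Rename the Relu output "y"; the Softmax input is re-routed automatically.
    rename_op_output(net, 0, 0, "relu_out")
    assert net.op[0].output[0] == "relu_out" and net.op[1].input[0] == "relu_out"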


def get_sub_graph_external_input_output(
    predict_net: caffe2_pb2.NetDef, sub_graph_op_indices: List[int]
) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
    """
    Return the lists of external inputs/outputs of the sub-graph;
    each element is a tuple of the blob name and its version in predict_net.

    External input/output is defined the same way as for a caffe2 NetDef.
    """
    ssa, versions = core.get_ssa(predict_net)

    all_inputs = []
    all_outputs = []
    for op_id in sub_graph_op_indices:
        all_inputs += [inp for inp in ssa[op_id][0] if inp not in all_inputs]
        all_outputs += list(ssa[op_id][1])  # ssa output won't repeat

    # for versioned blobs, external inputs are just those blobs that appear in
    # all_inputs but not in all_outputs
    ext_inputs = [inp for inp in all_inputs if inp not in all_outputs]

    # external outputs are essentially outputs of this sub-graph that are used
    # outside of this sub-graph (including predict_net.external_output)
    all_other_inputs = sum(
        (ssa[i][0] for i in range(len(ssa)) if i not in sub_graph_op_indices),
        [(outp, versions[outp]) for outp in predict_net.external_output],
    )
    ext_outputs = [outp for outp in all_outputs if outp in set(all_other_inputs)]

    return ext_inputs, ext_outputs


class DiGraph:
    """ A DAG representation of a caffe2 graph; each vertex is a versioned blob. """

    def __init__(self):
        self.vertices = set()
        self.graph = collections.defaultdict(list)

    def add_edge(self, u, v):
        self.graph[u].append(v)
        self.vertices.add(u)
        self.vertices.add(v)

    # adapted from https://www.geeksforgeeks.org/find-paths-given-source-destination/
    def get_all_paths(self, s, d):
        visited = {k: False for k in self.vertices}
        path = []
        all_paths = []

        def _get_all_paths_util(graph, u, d, visited, path):
            visited[u] = True
            path.append(u)
            if u == d:
                all_paths.append(copy.deepcopy(path))
            else:
                for i in graph[u]:
                    if not visited[i]:
                        _get_all_paths_util(graph, i, d, visited, path)
            path.pop()
            visited[u] = False

        _get_all_paths_util(self.graph, s, d, visited, path)
        return all_paths

    @staticmethod
    def from_ssa(ssa):
        graph = DiGraph()
        for op_id in range(len(ssa)):
            for inp in ssa[op_id][0]:
                for outp in ssa[op_id][1]:
                    graph.add_edge(inp, outp)
        return graph
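

# Illustrative sketch (not part of the upstream file): DiGraph path enumeration
# on plain strings instead of versioned blobs.
def _example_digraph_paths():
    g = DiGraph()
    g.add_edge("a", "b")
    g.add_edge("b", "c")
    g.add_edge("a", "c")
    # Two paths from "a" to "c": a -> b -> c and a -> c.
    assert sorted(g.get_all_paths("a", "c")) == [["a", "b", "c"], ["a", "c"]]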


def _get_dependency_chain(ssa, versioned_target, versioned_source):
    """
    Return the index list of the relevant operators needed to produce the target
    blob from the source blob; if there's no dependency, return an empty list.
    """

    # Finding all paths between nodes can be O(N!), so we only search within the
    # sub-graph spanning the ops from the first consumer of the source blob
    # to the producer of the target blob.
    consumer_map = get_consumer_map(ssa)
    producer_map = get_producer_map(ssa)
    start_op = min(x[0] for x in consumer_map[versioned_source]) - 15
    end_op = (
        producer_map[versioned_target][0] + 15 if versioned_target in producer_map else start_op
    )
    sub_graph_ssa = ssa[start_op : end_op + 1]
    if len(sub_graph_ssa) > 30:
        logger.warning(
            "Subgraph between {} and {} is large (from op#{} to op#{}); it"
            " might take non-trivial time to find all paths between them.".format(
                versioned_source, versioned_target, start_op, end_op
            )
        )

    dag = DiGraph.from_ssa(sub_graph_ssa)
    paths = dag.get_all_paths(versioned_source, versioned_target)  # include two ends
    ops_in_paths = [[producer_map[blob][0] for blob in path[1:]] for path in paths]
    return sorted(set().union(*[set(ops) for ops in ops_in_paths]))


def identify_reshape_sub_graph(predict_net: caffe2_pb2.NetDef) -> List[List[int]]:
    """
    Identify the reshape sub-graphs in a protobuf.
    A reshape sub-graph is defined as matching the following pattern:

    (input_blob) -> Op_1 -> ... -> Op_N -> (new_shape) -─┐
        └-------------------------------------------> Reshape -> (output_blob)

    Return:
        List of sub-graphs, where each sub-graph is represented as a list of
        indices of the relevant ops, [Op_1, Op_2, ..., Op_N, Reshape]
    """

    ssa, _ = core.get_ssa(predict_net)

    ret = []
    for i, op in enumerate(predict_net.op):
        if op.type == "Reshape":
            assert len(op.input) == 2
            input_ssa = ssa[i][0]
            data_source = input_ssa[0]
            shape_source = input_ssa[1]
            op_indices = _get_dependency_chain(ssa, shape_source, data_source)
            ret.append(op_indices + [i])
    return ret
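

# Illustrative sketch (not part of the upstream file): the Reshape pattern on a
# toy net where a Shape op derives new_shape from the blob that also feeds
# Reshape, so ops #0 and #1 form a single reshape sub-graph.
def _example_identify_reshape():
    from caffe2.python import core

    net = caffe2_pb2.NetDef()
    net.op.extend(
        [
            core.CreateOperator("Shape", ["x"], ["new_shape"]),
            core.CreateOperator("Reshape", ["x", "new_shape"], ["y", "old_shape"]),
        ]
    )
    assert identify_reshape_sub_graph(net) == [[0, 1]]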


def remove_reshape_for_fc(predict_net, params):
    """
    In PyTorch, nn.Linear has to take a 2D tensor; this often leads to reshaping
    a 4D tensor to 2D by calling .view(). However, this (dynamic) reshaping
    doesn't work well with ONNX and Int8 tools, and causes extra ops
    (e.g. ExpandDims) that might not be available on mobile.
    Luckily Caffe2 supports 4D tensors for FC, so those reshapes can be removed
    after exporting the ONNX model.
    """
    from caffe2.python import core

    # Find all reshape sub-graphs that can be removed; currently these are the
    # Reshape sub-graphs whose output is only consumed by FC.
    # TODO: to make it safer, we may need the actual value to better determine
    # if a Reshape before FC is removable.
    reshape_sub_graphs = identify_reshape_sub_graph(predict_net)
    sub_graphs_to_remove = []
    for reshape_sub_graph in reshape_sub_graphs:
        reshape_op_id = reshape_sub_graph[-1]
        assert predict_net.op[reshape_op_id].type == "Reshape"
        ssa, _ = core.get_ssa(predict_net)
        reshape_output = ssa[reshape_op_id][1][0]
        consumers = [i for i in range(len(ssa)) if reshape_output in ssa[i][0]]
        if all(predict_net.op[consumer].type == "FC" for consumer in consumers):
            # Safety check that the sub-graph is isolated: for this reshape
            # sub-graph, it must have one non-param external input and one
            # external output.
            ext_inputs, ext_outputs = get_sub_graph_external_input_output(
                predict_net, reshape_sub_graph
            )
            non_params_ext_inputs = [inp for inp in ext_inputs if inp[1] != 0]
            if len(non_params_ext_inputs) == 1 and len(ext_outputs) == 1:
                sub_graphs_to_remove.append(reshape_sub_graph)

    # Remove the sub-graphs by:
    # 1: renaming the Reshape's output to its input, so that the sub-graph can be
    #    seen as an in-place identity whose external input/output are the same.
    # 2: simply removing those ops.
    remove_op_ids = []
    params_to_remove = []
    for sub_graph in sub_graphs_to_remove:
        logger.info(
            "Remove Reshape sub-graph:\n{}".format(
                "".join(["(#{:>4})\n{}".format(i, predict_net.op[i]) for i in sub_graph])
            )
        )
        reshape_op_id = sub_graph[-1]
        new_reshap_output = predict_net.op[reshape_op_id].input[0]
        rename_op_output(predict_net, reshape_op_id, 0, new_reshap_output)
        ext_inputs, ext_outputs = get_sub_graph_external_input_output(predict_net, sub_graph)
        non_params_ext_inputs = [inp for inp in ext_inputs if inp[1] != 0]
        params_ext_inputs = [inp for inp in ext_inputs if inp[1] == 0]
        assert len(non_params_ext_inputs) == 1 and len(ext_outputs) == 1
        assert ext_outputs[0][0] == non_params_ext_inputs[0][0]
        assert ext_outputs[0][1] == non_params_ext_inputs[0][1] + 1
        remove_op_ids.extend(sub_graph)
        params_to_remove.extend(params_ext_inputs)

    predict_net = copy.deepcopy(predict_net)
    new_ops = [op for i, op in enumerate(predict_net.op) if i not in remove_op_ids]
    del predict_net.op[:]
    predict_net.op.extend(new_ops)
    for versioned_params in params_to_remove:
        name = versioned_params[0]
        logger.info("Remove params: {} from init_net and predict_net.external_input".format(name))
        del params[name]
        predict_net.external_input.remove(name)

    return predict_net, params


def fuse_copy_between_cpu_and_gpu(predict_net: caffe2_pb2.NetDef):
    """
    In-place fuse extra copy ops between cpu/gpu for the following case:
        a -CopyAToB-> b -CopyBToA-> c1 -NextOp1-> d1
                        -CopyBToA-> c2 -NextOp2-> d2
    The fused network will look like:
        a -NextOp1-> d1
          -NextOp2-> d2
    """

    _COPY_OPS = ["CopyCPUToGPU", "CopyGPUToCPU"]

    def _fuse_once(predict_net):
        ssa, blob_versions = core.get_ssa(predict_net)
        consumer_map = get_consumer_map(ssa)
        versioned_external_output = [
            (name, blob_versions[name]) for name in predict_net.external_output
        ]

        for op_id, op in enumerate(predict_net.op):
            if op.type in _COPY_OPS:
                fw_copy_versioned_output = ssa[op_id][1][0]
                consumer_ids = [x[0] for x in consumer_map[fw_copy_versioned_output]]
                reverse_op_type = _COPY_OPS[1 - _COPY_OPS.index(op.type)]

                is_fusable = (
                    len(consumer_ids) > 0
                    and fw_copy_versioned_output not in versioned_external_output
                    and all(
                        predict_net.op[_op_id].type == reverse_op_type
                        and ssa[_op_id][1][0] not in versioned_external_output
                        for _op_id in consumer_ids
                    )
                )

                if is_fusable:
                    for rv_copy_op_id in consumer_ids:
                        # make each NextOp use "a" directly, then remove the Copy ops
                        rs_copy_versioned_output = ssa[rv_copy_op_id][1][0]
                        next_op_id, inp_id = consumer_map[rs_copy_versioned_output][0]
                        predict_net.op[next_op_id].input[inp_id] = op.input[0]
                    # remove CopyOps
                    new_ops = [
                        op
                        for i, op in enumerate(predict_net.op)
                        if i != op_id and i not in consumer_ids
                    ]
                    del predict_net.op[:]
                    predict_net.op.extend(new_ops)
                    return True

        return False

    # _fuse_once returns False if nothing can be fused
    while _fuse_once(predict_net):
        pass


def remove_dead_end_ops(net_def: caffe2_pb2.NetDef):
    """ Remove ops whose outputs are not used and not in external_output. """
    ssa, versions = core.get_ssa(net_def)
    versioned_external_output = [(name, versions[name]) for name in net_def.external_output]
    consumer_map = get_consumer_map(ssa)
    removed_op_ids = set()

    def _is_dead_end(versioned_blob):
        return not (
            versioned_blob in versioned_external_output
            or (
                len(consumer_map[versioned_blob]) > 0
                and all(x[0] not in removed_op_ids for x in consumer_map[versioned_blob])
            )
        )

    for i, ssa_i in reversed(list(enumerate(ssa))):
        versioned_outputs = ssa_i[1]
        if all(_is_dead_end(outp) for outp in versioned_outputs):
            removed_op_ids.add(i)

    # simply removing those dead-end ops should have no effect on external_output
    new_ops = [op for i, op in enumerate(net_def.op) if i not in removed_op_ids]
    del net_def.op[:]
    net_def.op.extend(new_ops)
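The transforms above are usually chained after the ONNX → Caffe2 conversion. A minimal sketch of one possible ordering (the function names are the ones defined above; the predict_net/init_net/params objects are assumed to come from an earlier export step, and the exact ordering used by the exporter may differ):

    def optimize_exported_nets(predict_net, init_net, params):
        # Fold AliasWithName placeholders back into readable blob names.
        fuse_alias_placeholder(predict_net, init_net)
        # Drop Reshape sub-graphs that only feed FC (Caffe2 FC accepts 4D inputs).
        predict_net, params = remove_reshape_for_fc(predict_net, params)
        # Remove redundant CPU<->GPU copies, then any ops left without consumers.
        fuse_copy_between_cpu_and_gpu(predict_net)
        remove_dead_end_ops(predict_net)
        return predict_net, params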
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h
ADDED
@@ -0,0 +1,35 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#pragma once
#include <torch/types.h>

namespace detectron2 {

at::Tensor box_iou_rotated_cpu(
    const at::Tensor& boxes1,
    const at::Tensor& boxes2);

#ifdef WITH_CUDA
at::Tensor box_iou_rotated_cuda(
    const at::Tensor& boxes1,
    const at::Tensor& boxes2);
#endif

// Interface for Python
// inline is needed to prevent multiple function definitions when this header is
// included by different cpps
inline at::Tensor box_iou_rotated(
    const at::Tensor& boxes1,
    const at::Tensor& boxes2) {
  assert(boxes1.device().is_cuda() == boxes2.device().is_cuda());
  if (boxes1.device().is_cuda()) {
#ifdef WITH_CUDA
    return box_iou_rotated_cuda(boxes1.contiguous(), boxes2.contiguous());
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  }

  return box_iou_rotated_cpu(boxes1.contiguous(), boxes2.contiguous());
}

} // namespace detectron2
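A hedged usage sketch of the dispatcher declared above, from the Python side. In upstream detectron2 this function is bound into the compiled `detectron2._C` extension (and wrapped as `pairwise_iou_rotated`); the binding name is assumed here, not verified against this vendored copy:

    import torch
    from detectron2 import _C  # assumes the C++/CUDA extension has been built

    # Each box is (x_ctr, y_ctr, w, h, angle_in_degrees).
    boxes1 = torch.tensor([[10.0, 10.0, 20.0, 5.0, 0.0]])
    boxes2 = torch.tensor([[10.0, 10.0, 20.0, 5.0, 45.0]])
    ious = _C.box_iou_rotated(boxes1, boxes2)  # shape: (len(boxes1), len(boxes2))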
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp
ADDED
@@ -0,0 +1,39 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include "box_iou_rotated.h"
#include "box_iou_rotated_utils.h"

namespace detectron2 {

template <typename T>
void box_iou_rotated_cpu_kernel(
    const at::Tensor& boxes1,
    const at::Tensor& boxes2,
    at::Tensor& ious) {
  auto num_boxes1 = boxes1.size(0);
  auto num_boxes2 = boxes2.size(0);

  for (int i = 0; i < num_boxes1; i++) {
    for (int j = 0; j < num_boxes2; j++) {
      ious[i * num_boxes2 + j] = single_box_iou_rotated<T>(
          boxes1[i].data_ptr<T>(), boxes2[j].data_ptr<T>());
    }
  }
}

at::Tensor box_iou_rotated_cpu(
    // input must be contiguous:
    const at::Tensor& boxes1,
    const at::Tensor& boxes2) {
  auto num_boxes1 = boxes1.size(0);
  auto num_boxes2 = boxes2.size(0);
  at::Tensor ious =
      at::empty({num_boxes1 * num_boxes2}, boxes1.options().dtype(at::kFloat));

  box_iou_rotated_cpu_kernel<float>(boxes1, boxes2, ious);

  // reshape from 1d array to 2d array
  auto shape = std::vector<int64_t>{num_boxes1, num_boxes2};
  return ious.reshape(shape);
}

} // namespace detectron2
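The CPU kernel above fills a flat num_boxes1 × num_boxes2 buffer in row-major order and reshapes it at the end. A small NumPy sketch of the same indexing, with the per-pair IoU left as a caller-supplied placeholder (`single_iou` is a stand-in, not a real API):

    import numpy as np

    def pairwise_iou_rotated_reference(boxes1, boxes2, single_iou):
        # single_iou: any callable taking two (5,) arrays, e.g. a Python port
        # of single_box_iou_rotated.
        n1, n2 = len(boxes1), len(boxes2)
        ious = np.empty(n1 * n2, dtype=np.float32)
        for i in range(n1):
            for j in range(n2):
                ious[i * n2 + j] = single_iou(boxes1[i], boxes2[j])
        return ious.reshape(n1, n2)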
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu
ADDED
@@ -0,0 +1,130 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include "box_iou_rotated_utils.h"

namespace detectron2 {

// 2D block with 32 * 16 = 512 threads per block
const int BLOCK_DIM_X = 32;
const int BLOCK_DIM_Y = 16;

template <typename T>
__global__ void box_iou_rotated_cuda_kernel(
    const int n_boxes1,
    const int n_boxes2,
    const T* dev_boxes1,
    const T* dev_boxes2,
    T* dev_ious) {
  const int row_start = blockIdx.x * blockDim.x;
  const int col_start = blockIdx.y * blockDim.y;

  const int row_size = min(n_boxes1 - row_start, blockDim.x);
  const int col_size = min(n_boxes2 - col_start, blockDim.y);

  __shared__ float block_boxes1[BLOCK_DIM_X * 5];
  __shared__ float block_boxes2[BLOCK_DIM_Y * 5];

  // It's safe to copy using threadIdx.x since BLOCK_DIM_X >= BLOCK_DIM_Y
  if (threadIdx.x < row_size && threadIdx.y == 0) {
    block_boxes1[threadIdx.x * 5 + 0] =
        dev_boxes1[(row_start + threadIdx.x) * 5 + 0];
    block_boxes1[threadIdx.x * 5 + 1] =
        dev_boxes1[(row_start + threadIdx.x) * 5 + 1];
    block_boxes1[threadIdx.x * 5 + 2] =
        dev_boxes1[(row_start + threadIdx.x) * 5 + 2];
    block_boxes1[threadIdx.x * 5 + 3] =
        dev_boxes1[(row_start + threadIdx.x) * 5 + 3];
    block_boxes1[threadIdx.x * 5 + 4] =
        dev_boxes1[(row_start + threadIdx.x) * 5 + 4];
  }

  if (threadIdx.x < col_size && threadIdx.y == 0) {
    block_boxes2[threadIdx.x * 5 + 0] =
        dev_boxes2[(col_start + threadIdx.x) * 5 + 0];
    block_boxes2[threadIdx.x * 5 + 1] =
        dev_boxes2[(col_start + threadIdx.x) * 5 + 1];
    block_boxes2[threadIdx.x * 5 + 2] =
        dev_boxes2[(col_start + threadIdx.x) * 5 + 2];
    block_boxes2[threadIdx.x * 5 + 3] =
        dev_boxes2[(col_start + threadIdx.x) * 5 + 3];
    block_boxes2[threadIdx.x * 5 + 4] =
        dev_boxes2[(col_start + threadIdx.x) * 5 + 4];
  }
  __syncthreads();

  if (threadIdx.x < row_size && threadIdx.y < col_size) {
    int offset = (row_start + threadIdx.x) * n_boxes2 + col_start + threadIdx.y;
    dev_ious[offset] = single_box_iou_rotated<T>(
        block_boxes1 + threadIdx.x * 5, block_boxes2 + threadIdx.y * 5);
  }
}

at::Tensor box_iou_rotated_cuda(
    // input must be contiguous
    const at::Tensor& boxes1,
    const at::Tensor& boxes2) {
  using scalar_t = float;
  AT_ASSERTM(
      boxes1.scalar_type() == at::kFloat, "boxes1 must be a float tensor");
  AT_ASSERTM(
      boxes2.scalar_type() == at::kFloat, "boxes2 must be a float tensor");
  AT_ASSERTM(boxes1.is_cuda(), "boxes1 must be a CUDA tensor");
  AT_ASSERTM(boxes2.is_cuda(), "boxes2 must be a CUDA tensor");
  at::cuda::CUDAGuard device_guard(boxes1.device());

  auto num_boxes1 = boxes1.size(0);
  auto num_boxes2 = boxes2.size(0);

  at::Tensor ious =
      at::empty({num_boxes1 * num_boxes2}, boxes1.options().dtype(at::kFloat));

  bool transpose = false;
  if (num_boxes1 > 0 && num_boxes2 > 0) {
    scalar_t *data1 = boxes1.data_ptr<scalar_t>(),
             *data2 = boxes2.data_ptr<scalar_t>();

    if (num_boxes2 > 65535 * BLOCK_DIM_Y) {
      AT_ASSERTM(
          num_boxes1 <= 65535 * BLOCK_DIM_Y,
          "Too many boxes for box_iou_rotated_cuda!");
      // x dim is allowed to be large, but y dim cannot,
      // so we transpose the two to avoid "invalid configuration argument"
      // error. We assume one of them is small. Otherwise the result is hard to
      // fit in memory anyway.
      std::swap(num_boxes1, num_boxes2);
      std::swap(data1, data2);
      transpose = true;
    }

    const int blocks_x =
        at::cuda::ATenCeilDiv(static_cast<int>(num_boxes1), BLOCK_DIM_X);
    const int blocks_y =
        at::cuda::ATenCeilDiv(static_cast<int>(num_boxes2), BLOCK_DIM_Y);

    dim3 blocks(blocks_x, blocks_y);
    dim3 threads(BLOCK_DIM_X, BLOCK_DIM_Y);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    box_iou_rotated_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        num_boxes1,
        num_boxes2,
        data1,
        data2,
        (scalar_t*)ious.data_ptr<scalar_t>());

    AT_CUDA_CHECK(cudaGetLastError());
  }

  // reshape from 1d array to 2d array
  auto shape = std::vector<int64_t>{num_boxes1, num_boxes2};
  if (transpose) {
    return ious.view(shape).t();
  } else {
    return ious.view(shape);
  }
}

} // namespace detectron2
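CUDA caps gridDim.y at 65535, which is why the launcher above swaps the two inputs (and transposes the result) when boxes2 alone would require more than 65535 blocks along y. A plain-Python sketch of that launch-configuration arithmetic, mirroring the constants used in the kernel:

    BLOCK_DIM_X, BLOCK_DIM_Y = 32, 16  # 512 threads per block, as in the kernel
    GRID_Y_LIMIT = 65535

    def launch_config(num_boxes1, num_boxes2):
        transpose = False
        if num_boxes2 > GRID_Y_LIMIT * BLOCK_DIM_Y:
            assert num_boxes1 <= GRID_Y_LIMIT * BLOCK_DIM_Y, "too many boxes"
            num_boxes1, num_boxes2 = num_boxes2, num_boxes1
            transpose = True
        blocks_x = -(-num_boxes1 // BLOCK_DIM_X)  # ceiling division
        blocks_y = -(-num_boxes2 // BLOCK_DIM_Y)
        return (blocks_x, blocks_y), transpose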
Leffa/preprocess/humanparsing/mhp_extension/detectron2/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h
ADDED
@@ -0,0 +1,363 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#pragma once

#include <cassert>
#include <cmath>

#ifdef __CUDACC__
// Designates functions callable from the host (CPU) and the device (GPU)
#define HOST_DEVICE __host__ __device__
#define HOST_DEVICE_INLINE HOST_DEVICE __forceinline__
#else
#include <algorithm>
#define HOST_DEVICE
#define HOST_DEVICE_INLINE HOST_DEVICE inline
#endif

namespace detectron2 {

namespace {

template <typename T>
struct RotatedBox {
  T x_ctr, y_ctr, w, h, a;
};

template <typename T>
struct Point {
  T x, y;
  HOST_DEVICE_INLINE Point(const T& px = 0, const T& py = 0) : x(px), y(py) {}
  HOST_DEVICE_INLINE Point operator+(const Point& p) const {
    return Point(x + p.x, y + p.y);
  }
  HOST_DEVICE_INLINE Point& operator+=(const Point& p) {
    x += p.x;
    y += p.y;
    return *this;
  }
  HOST_DEVICE_INLINE Point operator-(const Point& p) const {
    return Point(x - p.x, y - p.y);
  }
  HOST_DEVICE_INLINE Point operator*(const T coeff) const {
    return Point(x * coeff, y * coeff);
  }
};

template <typename T>
HOST_DEVICE_INLINE T dot_2d(const Point<T>& A, const Point<T>& B) {
  return A.x * B.x + A.y * B.y;
}

// R: result type. can be different from input type
template <typename T, typename R = T>
HOST_DEVICE_INLINE R cross_2d(const Point<T>& A, const Point<T>& B) {
  return static_cast<R>(A.x) * static_cast<R>(B.y) -
      static_cast<R>(B.x) * static_cast<R>(A.y);
}

template <typename T>
HOST_DEVICE_INLINE void get_rotated_vertices(
    const RotatedBox<T>& box,
    Point<T> (&pts)[4]) {
  // M_PI / 180. == 0.01745329251
  double theta = box.a * 0.01745329251;
  T cosTheta2 = (T)cos(theta) * 0.5f;
  T sinTheta2 = (T)sin(theta) * 0.5f;

  // y: top --> down; x: left --> right
  pts[0].x = box.x_ctr + sinTheta2 * box.h + cosTheta2 * box.w;
  pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w;
  pts[1].x = box.x_ctr - sinTheta2 * box.h + cosTheta2 * box.w;
  pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w;
  pts[2].x = 2 * box.x_ctr - pts[0].x;
  pts[2].y = 2 * box.y_ctr - pts[0].y;
  pts[3].x = 2 * box.x_ctr - pts[1].x;
  pts[3].y = 2 * box.y_ctr - pts[1].y;
}

template <typename T>
HOST_DEVICE_INLINE int get_intersection_points(
    const Point<T> (&pts1)[4],
    const Point<T> (&pts2)[4],
    Point<T> (&intersections)[24]) {
  // Line vector
  // A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
  Point<T> vec1[4], vec2[4];
  for (int i = 0; i < 4; i++) {
    vec1[i] = pts1[(i + 1) % 4] - pts1[i];
    vec2[i] = pts2[(i + 1) % 4] - pts2[i];
  }

  // Line test - test all line combos for intersection
  int num = 0; // number of intersections
  for (int i = 0; i < 4; i++) {
    for (int j = 0; j < 4; j++) {
      // Solve for 2x2 Ax=b
      T det = cross_2d<T>(vec2[j], vec1[i]);

      // This takes care of parallel lines
      if (fabs(det) <= 1e-14) {
        continue;
      }

      auto vec12 = pts2[j] - pts1[i];

      T t1 = cross_2d<T>(vec2[j], vec12) / det;
      T t2 = cross_2d<T>(vec1[i], vec12) / det;

      if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) {
        intersections[num++] = pts1[i] + vec1[i] * t1;
      }
    }
  }

  // Check for vertices of rect1 inside rect2
  {
    const auto& AB = vec2[0];
    const auto& DA = vec2[3];
    auto ABdotAB = dot_2d<T>(AB, AB);
    auto ADdotAD = dot_2d<T>(DA, DA);
    for (int i = 0; i < 4; i++) {
      // assume ABCD is the rectangle, and P is the point to be judged
      // P is inside ABCD iff. P's projection on AB lies within AB
      // and P's projection on AD lies within AD

      auto AP = pts1[i] - pts2[0];

      auto APdotAB = dot_2d<T>(AP, AB);
      auto APdotAD = -dot_2d<T>(AP, DA);

      if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
          (APdotAD <= ADdotAD)) {
        intersections[num++] = pts1[i];
      }
    }
  }

  // Reverse the check - check for vertices of rect2 inside rect1
  {
    const auto& AB = vec1[0];
    const auto& DA = vec1[3];
    auto ABdotAB = dot_2d<T>(AB, AB);
    auto ADdotAD = dot_2d<T>(DA, DA);
    for (int i = 0; i < 4; i++) {
      auto AP = pts2[i] - pts1[0];

      auto APdotAB = dot_2d<T>(AP, AB);
      auto APdotAD = -dot_2d<T>(AP, DA);

      if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
          (APdotAD <= ADdotAD)) {
        intersections[num++] = pts2[i];
      }
    }
  }

  return num;
}

template <typename T>
HOST_DEVICE_INLINE int convex_hull_graham(
    const Point<T> (&p)[24],
    const int& num_in,
    Point<T> (&q)[24],
    bool shift_to_zero = false) {
  assert(num_in >= 2);

  // Step 1:
  // Find point with minimum y
  // if more than 1 points have the same minimum y,
  // pick the one with the minimum x.
  int t = 0;
  for (int i = 1; i < num_in; i++) {
    if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) {
      t = i;
    }
  }
  auto& start = p[t]; // starting point

  // Step 2:
  // Subtract starting point from every points (for sorting in the next step)
  for (int i = 0; i < num_in; i++) {
    q[i] = p[i] - start;
  }

  // Swap the starting point to position 0
  auto tmp = q[0];
  q[0] = q[t];
  q[t] = tmp;

  // Step 3:
  // Sort point 1 ~ num_in according to their relative cross-product values
  // (essentially sorting according to angles)
  // If the angles are the same, sort according to their distance to origin
  T dist[24];
#ifdef __CUDACC__
  // compute distance to origin before sort, and sort them together with the
  // points
  for (int i = 0; i < num_in; i++) {
    dist[i] = dot_2d<T>(q[i], q[i]);
  }

  // CUDA version
  // In the future, we can potentially use thrust
  // for sorting here to improve speed (though not guaranteed)
  for (int i = 1; i < num_in - 1; i++) {
    for (int j = i + 1; j < num_in; j++) {
      T crossProduct = cross_2d<T>(q[i], q[j]);
      if ((crossProduct < -1e-6) ||
          (fabs(crossProduct) < 1e-6 && dist[i] > dist[j])) {
        auto q_tmp = q[i];
        q[i] = q[j];
        q[j] = q_tmp;
        auto dist_tmp = dist[i];
        dist[i] = dist[j];
        dist[j] = dist_tmp;
      }
    }
  }
#else
  // CPU version
  std::sort(
      q + 1, q + num_in, [](const Point<T>& A, const Point<T>& B) -> bool {
        T temp = cross_2d<T>(A, B);
        if (fabs(temp) < 1e-6) {
          return dot_2d<T>(A, A) < dot_2d<T>(B, B);
        } else {
          return temp > 0;
        }
      });
  // compute distance to origin after sort, since the points are now different.
  for (int i = 0; i < num_in; i++) {
    dist[i] = dot_2d<T>(q[i], q[i]);
  }
#endif

  // Step 4:
  // Make sure there are at least 2 points (that don't overlap with each other)
  // in the stack
  int k; // index of the non-overlapped second point
  for (k = 1; k < num_in; k++) {
    if (dist[k] > 1e-8) {
      break;
    }
  }
  if (k == num_in) {
    // We reach the end, which means the convex hull is just one point
    q[0] = p[t];
    return 1;
  }
  q[1] = q[k];
  int m = 2; // 2 points in the stack
  // Step 5:
  // Finally we can start the scanning process.
  // When a non-convex relationship between the 3 points is found
  // (either concave shape or duplicated points),
  // we pop the previous point from the stack
  // until the 3-point relationship is convex again, or
  // until the stack only contains two points
  for (int i = k + 1; i < num_in; i++) {
    while (m > 1) {
      auto q1 = q[i] - q[m - 2], q2 = q[m - 1] - q[m - 2];
      // cross_2d() uses FMA and therefore computes round(round(q1.x*q2.y) -
      // q2.x*q1.y) So it may not return 0 even when q1==q2. Therefore we
      // compare round(q1.x*q2.y) and round(q2.x*q1.y) directly. (round means
      // round to nearest floating point).
      if (q1.x * q2.y >= q2.x * q1.y)
        m--;
      else
        break;
    }
    // Using double also helps, but float can solve the issue for now.
    // while (m > 1 && cross_2d<T, double>(q[i] - q[m - 2], q[m - 1] - q[m - 2])
    // >= 0) {
    //     m--;
    // }
    q[m++] = q[i];
  }

  // Step 6 (Optional):
  // In general sense we need the original coordinates, so we
  // need to shift the points back (reverting Step 2)
  // But if we're only interested in getting the area/perimeter of the shape
  // We can simply return.
  if (!shift_to_zero) {
    for (int i = 0; i < m; i++) {
      q[i] += start;
    }
  }

  return m;
}

template <typename T>
HOST_DEVICE_INLINE T polygon_area(const Point<T> (&q)[24], const int& m) {
  if (m <= 2) {
    return 0;
  }

  T area = 0;
  for (int i = 1; i < m - 1; i++) {
    area += fabs(cross_2d<T>(q[i] - q[0], q[i + 1] - q[0]));
  }

  return area / 2.0;
}

template <typename T>
HOST_DEVICE_INLINE T rotated_boxes_intersection(
    const RotatedBox<T>& box1,
    const RotatedBox<T>& box2) {
  // There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned
  // from rotated_rect_intersection_pts
  Point<T> intersectPts[24], orderedPts[24];

  Point<T> pts1[4];
  Point<T> pts2[4];
  get_rotated_vertices<T>(box1, pts1);
  get_rotated_vertices<T>(box2, pts2);

  int num = get_intersection_points<T>(pts1, pts2, intersectPts);

  if (num <= 2) {
    return 0.0;
  }

  // Convex Hull to order the intersection points in clockwise order and find
  // the contour area.
  int num_convex = convex_hull_graham<T>(intersectPts, num, orderedPts, true);
  return polygon_area<T>(orderedPts, num_convex);
}

} // namespace

template <typename T>
HOST_DEVICE_INLINE T
single_box_iou_rotated(T const* const box1_raw, T const* const box2_raw) {
  // shift center to the middle point to achieve higher precision in result
  RotatedBox<T> box1, box2;
  auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0;
  auto center_shift_y = (box1_raw[1] + box2_raw[1]) / 2.0;
  box1.x_ctr = box1_raw[0] - center_shift_x;
  box1.y_ctr = box1_raw[1] - center_shift_y;
  box1.w = box1_raw[2];
  box1.h = box1_raw[3];
  box1.a = box1_raw[4];
  box2.x_ctr = box2_raw[0] - center_shift_x;
  box2.y_ctr = box2_raw[1] - center_shift_y;
  box2.w = box2_raw[2];
  box2.h = box2_raw[3];
  box2.a = box2_raw[4];

  T area1 = box1.w * box1.h;
  T area2 = box2.w * box2.h;
  if (area1 < 1e-14 || area2 < 1e-14) {
    return 0.f;
  }

  T intersection = rotated_boxes_intersection<T>(box1, box2);
  T iou = intersection / (area1 + area2 - intersection);
  return iou;
}

} // namespace detectron2
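For validating the geometry above from Python, the corner construction in get_rotated_vertices and the final IoU formula can be mirrored directly. The sketch below uses shapely (an external dependency, used here only for the polygon intersection) as the reference intersection area; function names are illustrative:

    import math
    from shapely.geometry import Polygon

    def rotated_box_corners(x_ctr, y_ctr, w, h, angle_deg):
        # Same corner construction as get_rotated_vertices (angle in degrees,
        # y axis pointing down).
        theta = math.radians(angle_deg)
        c, s = 0.5 * math.cos(theta), 0.5 * math.sin(theta)
        p0 = (x_ctr + s * h + c * w, y_ctr + c * h - s * w)
        p1 = (x_ctr - s * h + c * w, y_ctr - c * h - s * w)
        p2 = (2 * x_ctr - p0[0], 2 * y_ctr - p0[1])
        p3 = (2 * x_ctr - p1[0], 2 * y_ctr - p1[1])
        return [p0, p1, p2, p3]

    def single_box_iou_rotated_reference(box1, box2):
        # box format: (x_ctr, y_ctr, w, h, angle_in_degrees)
        poly1 = Polygon(rotated_box_corners(*box1))
        poly2 = Polygon(rotated_box_corners(*box2))
        inter = poly1.intersection(poly2).area
        return inter / (poly1.area + poly2.area - inter)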