Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +14 -0
- perception_models/.gitignore +10 -0
- perception_models/CODE_OF_CONDUCT.md +80 -0
- perception_models/CONTRIBUTING.md +31 -0
- perception_models/LEGRAD_PE_USAGE.md +72 -0
- perception_models/LICENSE.PE +201 -0
- perception_models/LICENSE.PLM +124 -0
- perception_models/README.md +408 -0
- perception_models/__pycache__/legrad_pe_audio.cpython-310.pyc +0 -0
- perception_models/__pycache__/legrad_pe_audio.cpython-313.pyc +0 -0
- perception_models/__pycache__/legrad_pe_image.cpython-312.pyc +0 -0
- perception_models/__pycache__/legrad_pe_image.cpython-313.pyc +0 -0
- perception_models/apps/detection/DETA_pe/README.md +53 -0
- perception_models/apps/detection/DETA_pe/datasets/__init__.py +37 -0
- perception_models/apps/detection/DETA_pe/datasets/coco.py +345 -0
- perception_models/apps/detection/DETA_pe/datasets/coco_eval.py +265 -0
- perception_models/apps/detection/DETA_pe/datasets/coco_panoptic.py +107 -0
- perception_models/apps/detection/DETA_pe/datasets/data_prefetcher.py +70 -0
- perception_models/apps/detection/DETA_pe/datasets/objects365.py +54 -0
- perception_models/apps/detection/DETA_pe/datasets/panoptic_eval.py +52 -0
- perception_models/apps/detection/DETA_pe/datasets/samplers.py +348 -0
- perception_models/apps/detection/DETA_pe/datasets/torchvision_datasets/__init__.py +7 -0
- perception_models/apps/detection/DETA_pe/datasets/torchvision_datasets/coco.py +84 -0
- perception_models/apps/detection/DETA_pe/datasets/transforms.py +327 -0
- perception_models/apps/detection/DETA_pe/engine.py +303 -0
- perception_models/apps/detection/DETA_pe/engine_tta.py +239 -0
- perception_models/apps/detection/DETA_pe/main.py +754 -0
- perception_models/apps/detection/DETA_pe/models/__init__.py +15 -0
- perception_models/apps/detection/DETA_pe/models/assigner.py +378 -0
- perception_models/apps/detection/DETA_pe/models/backbone.py +235 -0
- perception_models/apps/detection/DETA_pe/models/deformable_detr.py +776 -0
- perception_models/apps/detection/DETA_pe/models/deformable_transformer.py +451 -0
- perception_models/apps/detection/DETA_pe/models/matcher.py +102 -0
- perception_models/apps/detection/DETA_pe/models/ops/functions/__init__.py +9 -0
- perception_models/apps/detection/DETA_pe/models/ops/functions/ms_deform_attn_func.py +106 -0
- perception_models/apps/detection/DETA_pe/models/ops/make.sh +10 -0
- perception_models/apps/detection/DETA_pe/models/ops/modules/__init__.py +9 -0
- perception_models/apps/detection/DETA_pe/models/ops/modules/ms_deform_attn.py +161 -0
- perception_models/apps/detection/DETA_pe/models/ops/setup.py +71 -0
- perception_models/apps/detection/DETA_pe/models/ops/src/cpu/ms_deform_attn_cpu.cpp +41 -0
- perception_models/apps/detection/DETA_pe/models/ops/src/cpu/ms_deform_attn_cpu.h +33 -0
- perception_models/apps/detection/DETA_pe/models/ops/src/cuda/ms_deform_attn_cuda.cu +153 -0
- perception_models/apps/detection/DETA_pe/models/ops/src/cuda/ms_deform_attn_cuda.h +30 -0
- perception_models/apps/detection/DETA_pe/models/ops/src/cuda/ms_deform_im2col_cuda.cuh +1327 -0
- perception_models/apps/detection/DETA_pe/models/ops/src/ms_deform_attn.h +62 -0
- perception_models/apps/detection/DETA_pe/models/ops/src/vision.cpp +16 -0
- perception_models/apps/detection/DETA_pe/models/ops/test.py +89 -0
- perception_models/apps/detection/DETA_pe/models/pev1.py +686 -0
- perception_models/apps/detection/DETA_pe/models/position_encoding.py +97 -0
- perception_models/apps/detection/DETA_pe/models/segmentation.py +369 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
perception_models/apps/pe/docs/assets/dog.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
perception_models/apps/pe/docs/assets/dog.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
perception_models/apps/pe/docs/assets/office.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
perception_models/apps/pe/docs/assets/office.wav filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
perception_models/apps/pe/docs/assets/pikachu.webp filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
perception_models/apps/pe/docs/assets/shark.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
perception_models/apps/pe/docs/assets/spatial_correspondence.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
perception_models/apps/pe/docs/assets/spatial_features.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
perception_models/apps/pe/docs/assets/teaser.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
perception_models/apps/pe/docs/assets/train.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
perception_models/apps/pe/docs/assets/train.wav filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
perception_models/apps/plm/docs/plm_main_fig.png filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
perception_models/core/tests/Rock-climbing-Canada-1920x1147.jpg filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
perception_models/core/tests/selfie_cathedral_peak.jpg filter=lfs diff=lfs merge=lfs -text
|
perception_models/.gitignore
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.pyc
|
| 2 |
+
.vscode
|
| 3 |
+
*.ipynb
|
| 4 |
+
slurm-*.out
|
| 5 |
+
wandb
|
| 6 |
+
data/*
|
| 7 |
+
data-gym-cache/*
|
| 8 |
+
torchinductor_*/*
|
| 9 |
+
tmp*/*
|
| 10 |
+
apps/plm/dummy_datasets
|
perception_models/CODE_OF_CONDUCT.md
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code of Conduct
|
| 2 |
+
|
| 3 |
+
## Our Pledge
|
| 4 |
+
|
| 5 |
+
In the interest of fostering an open and welcoming environment, we as
|
| 6 |
+
contributors and maintainers pledge to make participation in our project and
|
| 7 |
+
our community a harassment-free experience for everyone, regardless of age, body
|
| 8 |
+
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
| 9 |
+
level of experience, education, socio-economic status, nationality, personal
|
| 10 |
+
appearance, race, religion, or sexual identity and orientation.
|
| 11 |
+
|
| 12 |
+
## Our Standards
|
| 13 |
+
|
| 14 |
+
Examples of behavior that contributes to creating a positive environment
|
| 15 |
+
include:
|
| 16 |
+
|
| 17 |
+
* Using welcoming and inclusive language
|
| 18 |
+
* Being respectful of differing viewpoints and experiences
|
| 19 |
+
* Gracefully accepting constructive criticism
|
| 20 |
+
* Focusing on what is best for the community
|
| 21 |
+
* Showing empathy towards other community members
|
| 22 |
+
|
| 23 |
+
Examples of unacceptable behavior by participants include:
|
| 24 |
+
|
| 25 |
+
* The use of sexualized language or imagery and unwelcome sexual attention or
|
| 26 |
+
advances
|
| 27 |
+
* Trolling, insulting/derogatory comments, and personal or political attacks
|
| 28 |
+
* Public or private harassment
|
| 29 |
+
* Publishing others' private information, such as a physical or electronic
|
| 30 |
+
address, without explicit permission
|
| 31 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
| 32 |
+
professional setting
|
| 33 |
+
|
| 34 |
+
## Our Responsibilities
|
| 35 |
+
|
| 36 |
+
Project maintainers are responsible for clarifying the standards of acceptable
|
| 37 |
+
behavior and are expected to take appropriate and fair corrective action in
|
| 38 |
+
response to any instances of unacceptable behavior.
|
| 39 |
+
|
| 40 |
+
Project maintainers have the right and responsibility to remove, edit, or
|
| 41 |
+
reject comments, commits, code, wiki edits, issues, and other contributions
|
| 42 |
+
that are not aligned to this Code of Conduct, or to ban temporarily or
|
| 43 |
+
permanently any contributor for other behaviors that they deem inappropriate,
|
| 44 |
+
threatening, offensive, or harmful.
|
| 45 |
+
|
| 46 |
+
## Scope
|
| 47 |
+
|
| 48 |
+
This Code of Conduct applies within all project spaces, and it also applies when
|
| 49 |
+
an individual is representing the project or its community in public spaces.
|
| 50 |
+
Examples of representing a project or community include using an official
|
| 51 |
+
project e-mail address, posting via an official social media account, or acting
|
| 52 |
+
as an appointed representative at an online or offline event. Representation of
|
| 53 |
+
a project may be further defined and clarified by project maintainers.
|
| 54 |
+
|
| 55 |
+
This Code of Conduct also applies outside the project spaces when there is a
|
| 56 |
+
reasonable belief that an individual's behavior may have a negative impact on
|
| 57 |
+
the project or its community.
|
| 58 |
+
|
| 59 |
+
## Enforcement
|
| 60 |
+
|
| 61 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
| 62 |
+
reported by contacting the project team at <opensource-conduct@fb.com>. All
|
| 63 |
+
complaints will be reviewed and investigated and will result in a response that
|
| 64 |
+
is deemed necessary and appropriate to the circumstances. The project team is
|
| 65 |
+
obligated to maintain confidentiality with regard to the reporter of an incident.
|
| 66 |
+
Further details of specific enforcement policies may be posted separately.
|
| 67 |
+
|
| 68 |
+
Project maintainers who do not follow or enforce the Code of Conduct in good
|
| 69 |
+
faith may face temporary or permanent repercussions as determined by other
|
| 70 |
+
members of the project's leadership.
|
| 71 |
+
|
| 72 |
+
## Attribution
|
| 73 |
+
|
| 74 |
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
| 75 |
+
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
| 76 |
+
|
| 77 |
+
[homepage]: https://www.contributor-covenant.org
|
| 78 |
+
|
| 79 |
+
For answers to common questions about this code of conduct, see
|
| 80 |
+
https://www.contributor-covenant.org/faq
|
perception_models/CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributing to Perception Models
|
| 2 |
+
We want to make contributing to this project as easy and transparent as
|
| 3 |
+
possible.
|
| 4 |
+
|
| 5 |
+
## Pull Requests
|
| 6 |
+
We actively welcome your pull requests.
|
| 7 |
+
|
| 8 |
+
1. Fork the repo and create your branch from `main`.
|
| 9 |
+
2. If you've added code that should be tested, add tests.
|
| 10 |
+
3. If you've changed APIs, update the documentation.
|
| 11 |
+
4. Ensure the test suite passes.
|
| 12 |
+
5. Make sure your code lints.
|
| 13 |
+
6. If you haven't already, complete the Contributor License Agreement ("CLA").
|
| 14 |
+
|
| 15 |
+
## Contributor License Agreement ("CLA")
|
| 16 |
+
In order to accept your pull request, we need you to submit a CLA. You only need
|
| 17 |
+
to do this once to work on any of Facebook's open source projects.
|
| 18 |
+
|
| 19 |
+
Complete your CLA here: <https://code.facebook.com/cla>
|
| 20 |
+
|
| 21 |
+
## Issues
|
| 22 |
+
We use GitHub issues to track public bugs. Please ensure your description is
|
| 23 |
+
clear and has sufficient instructions to be able to reproduce the issue.
|
| 24 |
+
|
| 25 |
+
Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
|
| 26 |
+
disclosure of security bugs. In those cases, please go through the process
|
| 27 |
+
outlined on that page and do not file a public issue.
|
| 28 |
+
|
| 29 |
+
## License
|
| 30 |
+
By contributing to mae, you agree that your contributions will be licensed
|
| 31 |
+
under the LICENSE file in the root directory of this source tree.
|
perception_models/LEGRAD_PE_USAGE.md
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LeGrad + PE Perception Encoder Notebook Usage
|
| 2 |
+
|
| 3 |
+
This repository includes a notebook `legrad_perception_encoder.ipynb` that demonstrates how to run **LeGrad** explanations on the PE CoCa-style vision encoder.
|
| 4 |
+
|
| 5 |
+
## 1. Environment and installation
|
| 6 |
+
|
| 7 |
+
- **Install this repo** (from the repo root):
|
| 8 |
+
|
| 9 |
+
```bash
|
| 10 |
+
pip install -e .
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
- **Install LeGrad** (if not already installed):
|
| 14 |
+
|
| 15 |
+
```bash
|
| 16 |
+
pip install legrad
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
Make sure you have a working CUDA‑enabled PyTorch environment.
|
| 20 |
+
|
| 21 |
+
## 2. Open the notebook
|
| 22 |
+
|
| 23 |
+
From the repo root:
|
| 24 |
+
|
| 25 |
+
```bash
|
| 26 |
+
cd xai/perception_models
|
| 27 |
+
jupyter lab legrad_perception_encoder.ipynb
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
## 3. What the notebook does
|
| 31 |
+
|
| 32 |
+
The notebook shows how to:
|
| 33 |
+
|
| 34 |
+
1. Load a PE CoCa‑style vision encoder:
|
| 35 |
+
- Uses `pe.CLIP.from_config("PE-Core-B16-224", pretrained=True)` and moves the model to CUDA.
|
| 36 |
+
2. Wrap the model with LeGrad:
|
| 37 |
+
- `LeWrapper` lives in `core/legrad_pe.py`.
|
| 38 |
+
- It hooks PE residual blocks and attention pooling so gradients can be used to build visual explanations.
|
| 39 |
+
3. Prepare inputs:
|
| 40 |
+
- Build an image transform with `transforms.get_image_transform(model.image_size)`.
|
| 41 |
+
- Tokenize text prompts with `transforms.get_text_tokenizer(model.context_length)`.
|
| 42 |
+
4. Run LeGrad:
|
| 43 |
+
- **Multi‑layer explanation**:
|
| 44 |
+
- `heatmap = wrapped_model.compute_legrad_coca(text_emb, image=image_tensor)`
|
| 45 |
+
- **Single‑layer explanation**:
|
| 46 |
+
- `heatmap = wrapped_model.compute_legrad_coca_one_layer(text_emb, image=image_tensor, layer_idx=-1)`
|
| 47 |
+
5. Visualize:
|
| 48 |
+
- Convert the `heatmap` to numpy and use `legrad.visualize` (or standard plotting) to overlay it on the image.
|
| 49 |
+
|
| 50 |
+
## 4. Minimal code sketch (inside the notebook)
|
| 51 |
+
|
| 52 |
+
The core usage pattern is:
|
| 53 |
+
|
| 54 |
+
```python
|
| 55 |
+
import core.vision_encoder.pe as pe
|
| 56 |
+
import core.vision_encoder.transforms as transforms
|
| 57 |
+
from core.legrad_pe import LeWrapper
|
| 58 |
+
|
| 59 |
+
model = pe.CLIP.from_config("PE-Core-B16-224", pretrained=True).cuda()
|
| 60 |
+
preprocess = transforms.get_image_transform(model.image_size)
|
| 61 |
+
tokenizer = transforms.get_text_tokenizer(model.context_length)
|
| 62 |
+
|
| 63 |
+
wrapped_model = LeWrapper(model, layer_index=-2)
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
You can then:
|
| 67 |
+
|
| 68 |
+
- Preprocess an input image with `preprocess`,
|
| 69 |
+
- Tokenize prompts with `tokenizer`,
|
| 70 |
+
- Encode text/image, and
|
| 71 |
+
- Call one of the `compute_legrad_*` methods to obtain a heatmap for visualization.
|
| 72 |
+
|
perception_models/LICENSE.PE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
perception_models/LICENSE.PLM
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FAIR Noncommercial Research License
|
| 2 |
+
Last Updated: 17 April 2025
|
| 3 |
+
|
| 4 |
+
“Acceptable Use Policy” means the FAIR Acceptable Use Policy, applicable to Research Materials, that is incorporated into this Agreement.
|
| 5 |
+
|
| 6 |
+
“Agreement” means the terms and conditions for use, reproduction, distribution and modification of the Research Materials set forth herein.
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
“Documentation” means the specifications, manuals and documentation accompanying
|
| 10 |
+
Research Materials distributed by Meta.
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
“Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
“Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
|
| 17 |
+
|
| 18 |
+
“Noncommercial Research Uses” means noncommercial research use cases related to research, development, education, processing, or analysis and in each case, is not primarily intended for commercial advantage or monetary compensation to you or others.
|
| 19 |
+
|
| 20 |
+
“Research Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, demonstration materials and other elements of the foregoing distributed by Meta and made available under this Agreement.
|
| 21 |
+
|
| 22 |
+
By clicking “I Accept” below or by using or distributing any portion or element of the Research Materials, you agree to be bound by this Agreement.
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
1. License Rights and Redistribution.
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the Research Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Research Materials.
|
| 29 |
+
|
| 30 |
+
b. Redistribution and Use.
|
| 31 |
+
i. You will not use the Research Materials or any outputs or results of the Research Materials in connection with any commercial uses or for any uses other than Noncommercial Research Uses;
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
ii. Distribution of Research Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the Research Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement. You shall also provide a copy of this Agreement to such third party.
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
iii. If you submit for publication the results of research you perform on, using, or otherwise in connection with Research Materials, you must acknowledge the use of Research Materials in your publication.
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
iv. Your use of the Research Materials must comply with applicable laws and regulations (including Trade Control Laws) and adhere to the FAIR Acceptable Use Policy, which is hereby incorporated by reference into this Agreement.
|
| 41 |
+
2. User Support. Your Noncommercial Research Use of the Research Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the Research Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE RESEARCH MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS.
|
| 45 |
+
|
| 46 |
+
4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
| 47 |
+
|
| 48 |
+
5. Intellectual Property.
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
a. Subject to Meta’s ownership of Research Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Research Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
|
| 52 |
+
|
| 53 |
+
b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Research Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Research Materials.
|
| 54 |
+
|
| 55 |
+
6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Research Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Research Materials. Sections 5, 6 and 9 shall survive the termination of this Agreement.
|
| 56 |
+
|
| 57 |
+
7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
8. Modifications and Amendments. Meta may modify this Agreement from time to time by posting a revised version at [INSERT URL]; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the Research Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
FAIR Acceptable Use Policy
|
| 64 |
+
|
| 65 |
+
The Fundamental AI Research (FAIR) team at Meta seeks to further understanding of new and existing research domains with the mission of advancing the state-of-the-art in artificial intelligence through open research for the benefit of all.
|
| 66 |
+
|
| 67 |
+
As part of this mission, Meta makes certain research materials available for noncommercial research use. Meta is committed to promoting the safe and responsible use of such research materials.
|
| 68 |
+
|
| 69 |
+
Prohibited Uses
|
| 70 |
+
|
| 71 |
+
You agree you will not use, or allow others to use, Research Materials to:
|
| 72 |
+
|
| 73 |
+
Violate the law or others’ rights, including to:
|
| 74 |
+
Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as:
|
| 75 |
+
Violence or terrorism
|
| 76 |
+
Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material
|
| 77 |
+
Human trafficking, exploitation, and sexual violence
|
| 78 |
+
The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials.
|
| 79 |
+
Sexual solicitation
|
| 80 |
+
Any other criminal activity
|
| 81 |
+
|
| 82 |
+
Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals
|
| 83 |
+
|
| 84 |
+
Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services
|
| 85 |
+
|
| 86 |
+
Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices
|
| 87 |
+
|
| 88 |
+
Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws
|
| 89 |
+
|
| 90 |
+
Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any technology using FAIR research materials
|
| 91 |
+
|
| 92 |
+
Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system
|
| 93 |
+
|
| 94 |
+
2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of research artifacts related to the following:
|
| 95 |
+
|
| 96 |
+
Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State
|
| 97 |
+
|
| 98 |
+
Guns and illegal weapons (including weapon development)
|
| 99 |
+
|
| 100 |
+
Illegal drugs and regulated/controlled substances
|
| 101 |
+
|
| 102 |
+
Operation of critical infrastructure, transportation technologies, or heavy machinery
|
| 103 |
+
|
| 104 |
+
Self-harm or harm to others, including suicide, cutting, and eating disorders
|
| 105 |
+
|
| 106 |
+
Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual
|
| 107 |
+
|
| 108 |
+
3. Intentionally deceive or mislead others, including use of FAIR Research Materials related to the following:
|
| 109 |
+
|
| 110 |
+
Generating, promoting, or furthering fraud or the creation or promotion of disinformation
|
| 111 |
+
|
| 112 |
+
Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content
|
| 113 |
+
|
| 114 |
+
Generating, promoting, or further distributing spam
|
| 115 |
+
|
| 116 |
+
Impersonating another individual without consent, authorization, or legal right
|
| 117 |
+
|
| 118 |
+
Representing that outputs of FAIR research materials or outputs from technology using FAIR research materials are human-generated
|
| 119 |
+
|
| 120 |
+
Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement
|
| 121 |
+
|
| 122 |
+
4. Fail to appropriately disclose to end users any known dangers of your Research Materials.
|
| 123 |
+
|
| 124 |
+
Please report any violation of this Policy or other problems that could lead to a violation of this Policy by submitting a report here [https://docs.google.com/forms/d/e/1FAIpQLSeb11cryAopJ7LNrC4nxEUXrHY26hfkXQMf_uH-oFgA3WlYZQ/viewform].
|
perception_models/README.md
ADDED
|
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Perception Models: Powerful Models for Image, Video, and Audio Perception
|
| 2 |
+
[](https://opensource.org/licenses/Apache-2.0)
|
| 3 |
+
|
| 4 |
+
This repo is the home to the state-of-the-art for image and video _perception_: [**Perception Encoder (PE)**](https://arxiv.org/abs/2504.13181) for image, video, [audio](https://ai.meta.com/research/publications/pushing-the-frontier-of-audiovisual-perception-with-large-scale-multimodal-correspondence-learning/) encoding, and [**Perception Language Model (PLM)**](https://arxiv.org/abs/2504.13180) for decoding.
|
| 5 |
+
|
| 6 |
+
> [!TIP]
|
| 7 |
+
> Click to Navigate!
|
| 8 |
+
>
|
| 9 |
+
> [Perception Encoder and Perception Encoder Audio-Visual](#perception-encoder-pe)
|
| 10 |
+
>
|
| 11 |
+
> [Perception Language Model](#perception-language-model-plm)
|
| 12 |
+
>
|
| 13 |
+
> [Dataset Releases](#dataset-releases)
|
| 14 |
+
|
| 15 |
+
## Updates
|
| 16 |
+
* **[Dec-16-25]:** We have released the Perception Encoder Audio-Visual (PE-AV) and Perception Encoder Audio-Frame (PE-A-Frame) models: [[`Blog`](https://ai.meta.com/blog/sam-audio/)][[`paper`](https://ai.meta.com/research/publications/pushing-the-frontier-of-audiovisual-perception-with-large-scale-multimodal-correspondence-learning/)] :fire::fire:
|
| 17 |
+
* **[Jul-14-25]:** PerceptionLM is now available in [Hugging Face transformers](https://huggingface.co/docs/transformers/main/en/model_doc/perception_lm). :fire::fire:
|
| 18 |
+
* **[Jul-11-25]:** We have release 8 new checkpoints for [Perception Encoder](apps/pe/README.md): 2x small core models (T and S), 2x tiling-tuned lang models (G and L), and 4x smaller spatial models (L, B, S, T). Give them a try! :fire::fire::fire:
|
| 19 |
+
* **[May-28-25]:** Perception Encoder has been integrated into [timm](https://github.com/huggingface/pytorch-image-models)! :fire::fire:
|
| 20 |
+
* **[Apr-18-25]:** Perception Language Model (PLM) and PLM-VideoBench are added to lmms-eval. This makes it easy to reproduce PLM results and allows you to evaluate on the PLM-VideoBench. [[`lmms-eval`](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/638)] :fire::fire:
|
| 21 |
+
* **[Apr-17-25]:** Perception Encoder (PE) and Perception Language Model (PLM) are released. [[`Blog`](https://ai.meta.com/blog/meta-fair-updates-perception-localization-reasoning)] :fire::fire:
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
## Perception Encoder (PE)
|
| 25 |
+
[](https://huggingface.co/datasets/facebook/PE-Video)
|
| 26 |
+
[](https://huggingface.co/collections/facebook/perception-encoder-67f977c9a65ca5895a7f6ba1)
|
| 27 |
+
[](https://ai.meta.com/research/publications/perception-encoder-the-best-visual-embeddings-are-not-at-the-output-of-the-network)
|
| 28 |
+
[](https://ai.meta.com/research/publications/pushing-the-frontier-of-audiovisual-perception-with-large-scale-multimodal-correspondence-learning/)
|
| 29 |
+
[](https://arxiv.org/abs/2504.13181)
|
| 30 |
+
[](https://colab.research.google.com/github/facebookresearch/perception_models/blob/main/apps/pe/docs/pe_demo.ipynb)
|
| 31 |
+
[](https://opensource.org/licenses/Apache-2.0)
|
| 32 |
+
|
| 33 |
+
[Perception Encoder (PE)](https://arxiv.org/abs/2504.13181) is a family of the state-of-the-art vision and audio encoders for encoding images, video, and audio: PE core outperforms SigLIP2 on image and InternVideo2 on video benchmarks; PE lang can be used to outperform QwenVL2.5 and InternVL3 on vision language modeling; and PE spatial outperforms DINOv2 on dense prediction tasks. And all of this follows the same, easily scalable contrastive pretraining. Please see [README](apps/pe/README.md) for more details.
|
| 34 |
+
|
| 35 |
+
<img src="apps/pe/docs/assets/teaser.png" style="width: 100%; margin: 0 auto; display: block;" />
|
| 36 |
+
|
| 37 |
+
### Models
|
| 38 |
+
PE has 4 types of checkpoints, each excelling in a different area of computer vision and audio understanding:
|
| 39 |
+
- [PE core](#vision-language-benchmarks): a CLIP model excels in vision-language tasks such as zero-shot image and video classification and video retrieval.
|
| 40 |
+
- [PE lang](#multimodal-llm-benchmarks): a LLM-aligned PE that powers [PLM](https://arxiv.org/abs/2504.13180) to compete at the forefront of multimodal LLM benchmarks.
|
| 41 |
+
- [PE spatial](#vision-centric-benchmarks): a spatially tuned PE that outperforms best spatial models for vision-centric tasks such as detection, depth estimation, and tracking.
|
| 42 |
+
- [PE audio-visual](#audio-visual-benchmarks): a CLIP Model that embeds audio, video, audio-video, and text into a joint embedding space.
|
| 43 |
+
|
| 44 |
+
#### Vision-Language Benchmarks
|
| 45 |
+
| | Model | Checkpoint | IN-1k | IN-v2 | IN-A | ObjectNet | COCO-T2I | Kinetics-400 | VTT-T2V
|
| 46 |
+
|:--:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|
| 47 |
+
| | **T/16** 384px | [PE-Core-T16-384](https://huggingface.co/facebook/PE-Core-T16-384) | 62.1 | 54.7 | 21.1 | 43.9 | 33.0 | 41.5 | 28.8 |
|
| 48 |
+
| | **S/16** 384px | [PE-Core-S16-384](https://huggingface.co/facebook/PE-Core-S16-384) | 72.7 | 65.0 | 49.5 | 60.0 | 42.6 | 55.0 | 39.3 |
|
| 49 |
+
| | **B/16** 224px | [PE-Core-B16-224](https://huggingface.co/facebook/PE-Core-B16-224) | 78.4 | 71.7 | 62.4 | 71.9 | 50.9 | 65.6 | 47.6 |
|
| 50 |
+
| | **L/14** 336px | [PE-Core-L14-336](https://huggingface.co/facebook/PE-Core-L14-336) | 83.5 | 77.9 | 89.0 | 84.7 | 57.1 | 73.4 | 50.3 |
|
| 51 |
+
| | **G/14** 448px | [PE-Core-G14-448](https://huggingface.co/facebook/PE-Core-G14-448) | 85.4 | 80.2 | 92.6 | 88.2 | 58.1 | 76.9 | 51.2 |
|
| 52 |
+
|
| 53 |
+
#### Multimodal LLM Benchmarks
|
| 54 |
+
|
| 55 |
+
🔬 Controlled Setting:
|
| 56 |
+
| | Encoder | Checkpoint | Doc VQA (val) | InfoQA (val) | TextVQA | MVBench | PerceptionTest (val) | EgoSchema (val) |
|
| 57 |
+
|:--:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|
| 58 |
+
| | **L/14** 448px | [PE-Lang-L14-448](https://huggingface.co/facebook/PE-Lang-L14-448) | 81.9 | 46.4 | 73.0 | 52.3 | 54.7 | 59.8 |
|
| 59 |
+
| | **G/14** 448px | [PE-Lang-G14-448](https://huggingface.co/facebook/PE-Lang-G14-448) | 84.4 | 48.3 | 75.2 | 52.4 | 56.0 | 62.0 |
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
🔥 SotA Setting:
|
| 63 |
+
| | Model | Encoder | Doc VQA (test) | InfoQA (test) | TextVQA | MVBench | PerceptionTest (test) | EgoSchema (test) |
|
| 64 |
+
|:--:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|
| 65 |
+
| | PLM-3B | [PE-Lang-L14-448-Tiling](https://huggingface.co/facebook/PE-Lang-L14-448-Tiling)* | 93.8 | 74.6 | 84.3 | 74.7 | 79.3 | 66.9 |
|
| 66 |
+
| | PLM-8B | [PE-Lang-G14-448-Tiling](https://huggingface.co/facebook/PE-Lang-G14-448-Tiling)* | 94.6 | 80.9 | 86.5 | 77.1 | 82.7 | 68.8 |
|
| 67 |
+
|
| 68 |
+
\* These checkpoints were aligned with tiling. Use them if you use higher than 448 resolution with tiling in the LLM decoder.
|
| 69 |
+
|
| 70 |
+
#### Vision-centric Benchmarks
|
| 71 |
+
🦾 Main model:
|
| 72 |
+
| | Encoder | Checkpoint | ADE20k <br/> [Segmentation](https://github.com/open-mmlab/mmsegmentation)<br />Linear Probe mIoU | DAVIS<br /> [Tracking](https://github.com/facebookresearch/dino/blob/main/eval_video_segmentation.py) <br />Zero-Shot J&F | LVIS <br /> [Mask R-CNN](../detection/detectron2_pe/) 1024px <br /> Box / Mask mAP | COCO <br/> [DETA](../detection/DETA_pe/) 1824px <br /> Box mAP |
|
| 73 |
+
|:--:|:---:|:---:|:---:|:---:|:---:|:---:|
|
| 74 |
+
| | **G/14** 448px | [PE-Spatial-G14-448](https://huggingface.co/facebook/PE-Spatial-G14-448) | 49.3 | 61.5 | 54.2 / 49.3 | 66.0 |
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
<div align="center">
|
| 78 |
+
<img src="apps/pe/docs/assets/spatial_correspondence.png" style="width: 80%; margin: 0 auto; padding-top: 20px; padding-bottom: 20px; display: block;" />
|
| 79 |
+
|
| 80 |
+
Visualization of PCA of non-maked visual tokens, mapped to RGB values.
|
| 81 |
+
</div>
|
| 82 |
+
|
| 83 |
+
⚗️ Distilled Models:
|
| 84 |
+
| | Encoder<br />(Distilled from G) | Checkpoint | ADE20k <br/> [Segmentation](https://github.com/open-mmlab/mmsegmentation)<br />Linear Probe mIoU | DAVIS<br /> [Tracking](https://github.com/facebookresearch/dino/blob/main/eval_video_segmentation.py) <br />Zero-Shot J&F |
|
| 85 |
+
|:--:|:---:|:---:|:---:|:---:|
|
| 86 |
+
| | **T/16** 512px | [PE-Spatial-T16-512](https://huggingface.co/facebook/PE-Spatial-T16-512) | 27.6 | 55.0 |
|
| 87 |
+
| | **S/16** 512px | [PE-Spatial-S16-512](https://huggingface.co/facebook/PE-Spatial-S16-512) | 37.5 | 57.5 |
|
| 88 |
+
| | **B/16** 512px | [PE-Spatial-B16-512](https://huggingface.co/facebook/PE-Spatial-B16-512) | 44.4 | 58.9 |
|
| 89 |
+
| | **L/14** 448px | [PE-Spatial-L14-448](https://huggingface.co/facebook/PE-Spatial-L14-448) | 48.1 | 60.6 |
|
| 90 |
+
|
| 91 |
+
See paper for comparison to other models.
|
| 92 |
+
|
| 93 |
+
#### Audio-Visual Benchmarks
|
| 94 |
+
|
| 95 |
+
| | Model | Checkpoint | Avg Retrieval | AudioCaps T→A | AudioCaps T→V | AudioCaps V→A | Clotho T→A | Valor T→A | Valor T→V | VCTK A→T | VGGSound V→A | Internal V→A |
|
| 96 |
+
|:--:|:-----:|--------------|---------------|---------------|---------------|---------------|------------|-----------|-----------|----------|---------------|---------------|
|
| 97 |
+
| 🆕 | **AV S** 16 frames | [`pe-av-small-16-frame`](https://huggingface.co/facebook/pe-av-small-16-frame) | 45.2 | 41.2 | 18.6 | 75.4 | 24.0 | 29.8 | 70.1 | 96.1 | 34.1 | 17.9 |
|
| 98 |
+
| 🆕 | **AV B** 16 frames | [`pe-av-base-16-frame`](https://huggingface.co/facebook/pe-av-base-16-frame) | 47.0 | 43.1 | 19.8 | 80.6 | 23.4 | 31.9 | 70.0 | 94.8 | 39.0 | 20.4 |
|
| 99 |
+
| 🆕 | **AV L** 16 frames | [`pe-av-large-16-frame`](https://huggingface.co/facebook/pe-av-large-16-frame) | 48.2 | 44.7 | 19.5 | 86.1 | 22.8 | 35.0 | 70.9 | 85.6 | 45.2 | 23.9 |
|
| 100 |
+
| 🆕 | **AV S** all frames | [`pe-av-small`](https://huggingface.co/facebook/pe-av-small) | 48.1 | 41.8 | 18.8 | 77.4 | 23.9 | 29.3 | 70.9 | 94.9 | 35.4 | 40.5 |
|
| 101 |
+
| 🆕 | **AV B** all frames | [`pe-av-base`](https://huggingface.co/facebook/pe-av-base) | 50.2 | 42.7 | 19.6 | 83.7 | 23.8 | 30.8 | 71.2 | 94.9 | 40.7 | 44.6 |
|
| 102 |
+
| 🆕 | **AV L** all frames | [`pe-av-large`](https://huggingface.co/facebook/pe-av-large) | 51.6 | 45.8 | 20.8 | 88.3 | 23.0 | 35.1 | 70.9 | 85.6 | 48.3 | 46.5 |
|
| 103 |
+
|
| 104 |
+
#### Audio Event Localization Benchmarks
|
| 105 |
+
|
| 106 |
+
| | Model | Checkpoint | Internal Bench (AUROC) | ASFX-SED (AUROC) | AudioSet-Strong (AUROC) | DESED (AUROC) | UrbanSED (AUROC) |
|
| 107 |
+
|:--:|:-----:|------------------|---------------------|------------------|-----------------------|-------------|-------------|
|
| 108 |
+
| 🆕 | **A-Frame S** | [`pe-a-frame-small`](https://huggingface.co/facebook/pe-a-frame-small)| 0.91 | 0.83 | 0.96 | 0.96 | 0.88 |
|
| 109 |
+
| 🆕 | **A-Frame B** | [`pe-a-frame-base`](https://huggingface.co/facebook/pe-a-frame-base)| 0.92 | 0.83 | 0.96 | 0.98 | 0.89 |
|
| 110 |
+
| 🆕 | **A-Frame L** | [`pe-a-frame-large`](https://huggingface.co/facebook/pe-a-frame-large)| 0.91 | 0.83 | 0.96 | 0.97 | 0.89 |
|
| 111 |
+
|
| 112 |
+
### Getting Started with PE
|
| 113 |
+
You can get started with the following example for image and text feature extraction or use our [Colab Demo](https://colab.research.google.com/github/facebookresearch/perception_models/blob/main/apps/pe/docs/pe_demo.ipynb)
|
| 114 |
+
|
| 115 |
+
```python
|
| 116 |
+
import torch
|
| 117 |
+
from PIL import Image
|
| 118 |
+
import core.vision_encoder.pe as pe
|
| 119 |
+
import core.vision_encoder.transforms as transforms
|
| 120 |
+
|
| 121 |
+
print("CLIP configs:", pe.CLIP.available_configs())
|
| 122 |
+
# CLIP configs: ['PE-Core-G14-448', 'PE-Core-L14-336', 'PE-Core-B16-224', 'PE-Core-S16-384', 'PE-Core-T16-384']
|
| 123 |
+
|
| 124 |
+
model = pe.CLIP.from_config("PE-Core-L14-336", pretrained=True) # Downloads from HF
|
| 125 |
+
model = model.cuda()
|
| 126 |
+
|
| 127 |
+
preprocess = transforms.get_image_transform(model.image_size)
|
| 128 |
+
tokenizer = transforms.get_text_tokenizer(model.context_length)
|
| 129 |
+
|
| 130 |
+
image = preprocess(Image.open("docs/assets/cat.png")).unsqueeze(0).cuda()
|
| 131 |
+
text = tokenizer(["a diagram", "a dog", "a cat"]).cuda()
|
| 132 |
+
|
| 133 |
+
with torch.no_grad(), torch.autocast("cuda"):
|
| 134 |
+
image_features, text_features, logit_scale = model(image, text)
|
| 135 |
+
text_probs = (logit_scale * image_features @ text_features.T).softmax(dim=-1)
|
| 136 |
+
|
| 137 |
+
print("Label probs:", text_probs) # prints: [[0.0, 0.0, 1.0]]
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
> [!TIP]
|
| 141 |
+
> See [`apps/pe/README.md`](apps/pe/README.md) for details and how to get started!
|
| 142 |
+
|
| 143 |
+
### Getting Started with PE-AV
|
| 144 |
+
|
| 145 |
+
```python
|
| 146 |
+
import os
|
| 147 |
+
from core.audio_visual_encoder import PEAudioVisual, PEAudioVisualTransform
|
| 148 |
+
import torch
|
| 149 |
+
|
| 150 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 151 |
+
model = PEAudioVisual.from_config("pe-av-large", pretrained=True).to(device)
|
| 152 |
+
transform = PEAudioVisualTransform.from_config("pe-av-large")
|
| 153 |
+
|
| 154 |
+
video_files = ["assets/train.mp4", "assets/office.mp4"]
|
| 155 |
+
descriptions = [
|
| 156 |
+
"A person talking with sirens and a train in the background",
|
| 157 |
+
"Two people talking in an office, with sounds of workers typing on a keyboard"
|
| 158 |
+
]
|
| 159 |
+
|
| 160 |
+
def embed(videos=None, audio=None, text=None):
|
| 161 |
+
inputs = transform(videos=videos, audio=audio, text=text)
|
| 162 |
+
inputs = inputs.to(device)
|
| 163 |
+
with torch.inference_mode(), torch.autocast(device.type, dtype=torch.bfloat16):
|
| 164 |
+
return model(**inputs)
|
| 165 |
+
|
| 166 |
+
vt_outputs = embed(videos=video_files, text=descriptions)
|
| 167 |
+
avt_outputs = embed(videos=video_files, audio=video_files, text=descriptions)
|
| 168 |
+
at_outputs = embed(audio=video_files, text=descriptions)
|
| 169 |
+
|
| 170 |
+
# Compute dot product between visual and text
|
| 171 |
+
vt_dot_products = torch.einsum("ij,ij->i", vt_outputs.visual_embeds, vt_outputs.visual_text_embeds)
|
| 172 |
+
# Compute dot product between audio_visual and text
|
| 173 |
+
avt_dot_products = torch.einsum("ij,ij->i", avt_outputs.audio_visual_embeds, avt_outputs.audio_visual_text_embeds)
|
| 174 |
+
# Compute dot product between audio and text
|
| 175 |
+
at_dot_products = torch.einsum("ij,ij->i", at_outputs.audio_embeds, at_outputs.audio_text_embeds)
|
| 176 |
+
# Compute dot product between audio and video
|
| 177 |
+
av_dot_products = torch.einsum("ij,ij->i", avt_outputs.audio_embeds, avt_outputs.video_embeds)
|
| 178 |
+
```
|
| 179 |
+
|
| 180 |
+
### Getting Started with PE-A-Frame
|
| 181 |
+
|
| 182 |
+
```python
|
| 183 |
+
from core.audio_visual_encoder import (
|
| 184 |
+
PEAudioFrame,
|
| 185 |
+
PEAudioFrameTransform,
|
| 186 |
+
)
|
| 187 |
+
import torch
|
| 188 |
+
|
| 189 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 190 |
+
model = PEAudioFrame.from_config("pe-a-frame-large", pretrained=True).to(device)
|
| 191 |
+
transform = PEAudioFrameTransform.from_config("pe-a-frame-large")
|
| 192 |
+
|
| 193 |
+
descriptions = ["a person talking"]
|
| 194 |
+
inputs = transform(
|
| 195 |
+
audio=["assets/office.mp4"],
|
| 196 |
+
text=descriptions,
|
| 197 |
+
).to(device)
|
| 198 |
+
|
| 199 |
+
with torch.inference_mode():
|
| 200 |
+
outputs = model(**inputs)
|
| 201 |
+
|
| 202 |
+
# Print the spans for each description (start and end timestamps for when they occur in the audio)
|
| 203 |
+
for description, spans in zip(descriptions, outputs.spans):
|
| 204 |
+
span_str = ", ".join([f"({start:.2f}, {end:.2f})" for start, end in spans])
|
| 205 |
+
print(f'"{description}": [{span_str}]')
|
| 206 |
+
|
| 207 |
+
```
|
| 208 |
+
|
| 209 |
+
> [!TIP]
|
| 210 |
+
> See [`apps/pe/README.md`](apps/pe/README.md) for additional details!
|
| 211 |
+
|
| 212 |
+
## Perception Language Model (PLM)
|
| 213 |
+
[](https://huggingface.co/datasets/facebook/PLM-Video-Human)
|
| 214 |
+
[](https://huggingface.co/collections/facebook/perception-lm-67f9783f171948c383ee7498)
|
| 215 |
+
[](https://ai.meta.com/research/publications/perceptionlm-open-access-data-and-models-for-detailed-visual-understanding)
|
| 216 |
+
[](https://arxiv.org/abs/2504.13180)
|
| 217 |
+
[](apps/plm/notebook_demos)
|
| 218 |
+
[](LICENSE.PLM)
|
| 219 |
+
|
| 220 |
+
PerceptionLM (PLM) is a family of open and fully reproducible models to facilitate research in vision-language modeling (VLM). In conjunction with PE, it is powerful enough to compete with the latest state-of-the-art VLMs such as InternVL3 and QwenVL2.5, while using _fully open data_. We also release the largest spatiotemporally annotated video dense captioning and fine-grained human activity recognition datasets to ever exist.
|
| 221 |
+
|
| 222 |
+

|
| 223 |
+
|
| 224 |
+
### Models
|
| 225 |
+
PLM releases models in three different sizes (1B, 3B and 8B).
|
| 226 |
+
* [Perception-LM-1B](https://huggingface.co/facebook/Perception-LM-1B): A PLM model trained using Llama-3.2-1B-Instruct base LLM.
|
| 227 |
+
* [Perception-LM-3B](https://huggingface.co/facebook/Perception-LM-3B): A PLM model trained using Llama-3.2-3B-Instruct base LLM.
|
| 228 |
+
* [Perception-LM-8B](https://huggingface.co/facebook/Perception-LM-8B): A PLM model trained using Llama-3.1-8B-Instruct base LLM.
|
| 229 |
+
|
| 230 |
+
#### PLM Image Benchmark Results
|
| 231 |
+
|
| 232 |
+
| Model | DocVQA | ChartQA | TextVQA | InfoQA | AI2D | OCRBench | COCO | Nocap | Flickr | MMMU | VQAv2 | OKVQA | VizWiz | MME | SEED | BLINK | CVBench | RealWorldQA | VSR | POPE |
|
| 233 |
+
|:---------:|:--------:|:---------:|:---------:|:--------:|:------:|:----------:|:------------:|:-------------:|:--------------:|:------:|:-------:|:--------:|:--------:|:-----:|:------:|:-------:|:----------:|:-------------:|:-----:|:------:|
|
| 234 |
+
| PLM1B | 90.7 | 78.6 | 82.1 | 63.0 | 84.9 | 807 | 138.6 | 124.2 | 100.5 | 34.8 | 81.7 | 61.0 | 59.7 | 1603| 76.3 | 46.8 | 73.8 | 67.1 | 68.8| 88.4 |
|
| 235 |
+
| PLM3B | 93.8 | 84.3 | 84.3 | 74.6 | 90.9 | 830 | 144.9 | 126.5 | 98.0 | 41.2 | 84.3 | 66.8 | 64.0 | 1879| 78.5 | 55.4 | 81.4 | 72.4 | 80.4| 88.7 |
|
| 236 |
+
| PLM8B | 94.6 | 85.5 | 86.5 | 80.9 | 92.7 | 870 | 146.7 | 129.9 | 105.6 | 46.1 | 85.6 | 69.6 | 67.0 | 1989| 79.3 | 56.0 | 81.3 | 75.0 | 82.8| 89.9 |
|
| 237 |
+
|
| 238 |
+
#### PLM Video Benchmark Results
|
| 239 |
+
|
| 240 |
+
| Model | VATEX | DREAM 1K | How2QA | MVBench | NExTQA | PerceptionTest (test) | STAR | TVQA | VideoMME | TVBench | ActivityNetQA | EgoSchema (test) | TemporalBench | TOMATO | MotionBench (dev) | TempCompass (MCQ) | CGBench (clue) | Charades STA | VideoHallucer | Halluc. EventHallusion |
|
| 241 |
+
|:-------------:|:---------------------------:|:-----------------------:|:---------------------:|:-------------:|:-------------:|:--------------------------:|:----------:|:----------:|:----------------:|:-------------:|:--------------------:|:----------------------:|:---------------------:|:------------:|:------------------------:|:-----------------------:|:---------------------:|:-------------------:|:-------------------------------:|:--------------------------------:|
|
| 242 |
+
| PLM1B | 92.5 | 34.3 | 86.4 | 70.1 | 80.3 | 72.7 | 83.7 | 50.3 | 49.2 | 50.4 | 62.5 | 60.4 | 18.2 | 25.5 | 52.2 | 64.6 | 43.6 | 55.2 | 49.2 | 79.5 |
|
| 243 |
+
| PLM3B | 96.1 | 37.4 | 89.4 | 74.7 | 83.4 | 79.3 | 84.8 | 55.3 | 54.9 | 58.9 | 66.2 | 66.9 | 23.4 | 30.9 | 60.4 | 69.3 | 47.2 | 57.7 | 55.5 | 76.5 |
|
| 244 |
+
| PLM8B | 99.7 | 35.9 | 90.7 | 77.1 | 84.1 | 82.7 | 84.9 | 59.3 | 58.3 | 63.5 | 67.3 | 68.8 | 28.3 | 33.2 | 61.4 | 72.7 | 46.4 | 58.6 | 57.7 | 77.3 |
|
| 245 |
+
|
| 246 |
+
### PLM Resources
|
| 247 |
+
|
| 248 |
+
| Resource | Description | Documentation |
|
| 249 |
+
| --- | --- |--------------------------------------------------------|
|
| 250 |
+
| **Evaluation** | Evaluation of PLM using lmms-eval | [`docs/evaluation.md`](apps/plm/docs/evaluation.md) |
|
| 251 |
+
| **Training / Finetuning** | Training and finetuning instructions for PLM | [`docs/training.md`](apps/plm/docs/training.md) |
|
| 252 |
+
| **PLM-VideoBench** | Evaluation on PLM-VideoBench using lmms-eval | [`docs/plm_videobench.md`](apps/plm/docs/plm_videobench.md) |
|
| 253 |
+
| **End-to-End Finetuning Example** | End-to-end finetuning example on radiology images | [`docs/finetune_example.md`](apps/plm/docs/finetune_example.md) |
|
| 254 |
+
| **Generating Response** | Generate responses using a trained model with `generate.py` | [`generate.py`](apps/plm/generate.py) |
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
> [!TIP]
|
| 258 |
+
> See [`apps/plm/README.md`](apps/plm/README.md) for details and how to get started!
|
| 259 |
+
|
| 260 |
+
## Dataset Releases
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
### 🎥 [PE-Video-Dataset (PVD)](https://huggingface.co/datasets/facebook/PE-Video)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
PVD comprises 1M high quality and diverse videos. Among them, 120K videos are accompanied by automated and human-verified annotations. and all videos are accompanied with video description and keywords. The videos are motion-centered, covering both first-person and third-person views with a wide coverage of scenes.
|
| 267 |
+
|
| 268 |
+
🔹 [**PVD**](https://huggingface.co/datasets/facebook/PE-Video) - 1M High-Quality Human Annotated Video Dataset
|
| 269 |
+
|
| 270 |
+
<table>
|
| 271 |
+
<tr>
|
| 272 |
+
<td colspan="2" align="center"><strong>PVD</strong></td>
|
| 273 |
+
</tr>
|
| 274 |
+
<tr>
|
| 275 |
+
<td align="center">
|
| 276 |
+
<img src="https://github.com/user-attachments/assets/ead8a7ed-4d5b-465a-a396-68948683dfcf" alt="output_2" width="300"/><br>
|
| 277 |
+
A person's hands pruning a plant with green leaves.
|
| 278 |
+
</td>
|
| 279 |
+
<td align="center">
|
| 280 |
+
<img src="https://github.com/user-attachments/assets/9e509e49-f550-4c5c-9571-ed57c5118227" alt="output" width="300"/><br>
|
| 281 |
+
A detailed diorama of a rural landscape featuring a horse-drawn carriage moving along a dirt path
|
| 282 |
+
</td>
|
| 283 |
+
</tr>
|
| 284 |
+
</table>
|
| 285 |
+
|
| 286 |
+
---
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
### 🎥 [PLM-Video-Human](https://huggingface.co/datasets/facebook/PLM-Video-Human)
|
| 290 |
+
|
| 291 |
+
PLM-Video-Human is a collection of human-annotated resources for training Vision Language Models, focused on detailed video understanding. Training tasks include:
|
| 292 |
+
|
| 293 |
+
🔹 [**FGQA**](https://huggingface.co/datasets/facebook/PLM-Video-Human#fine-grained-question-answering-fgqa) — Fine-Grained Question Answering
|
| 294 |
+
🔹 [**RTLoc**](https://huggingface.co/datasets/facebook/PLM-Video-Human#region-temporal-localization-rtloc) — Region-Temporal Localization
|
| 295 |
+
🔹 [**RCap**](https://huggingface.co/datasets/facebook/PLM-Video-Human#region-video-captioning-rcap) — Region Video Captioning
|
| 296 |
+
🔹 [**RDCap**](https://huggingface.co/datasets/facebook/PLM-Video-Human#region-dense-temporal-captioning-rdcap) — Region Dense Temporal Captioning
|
| 297 |
+
|
| 298 |
+
<table>
|
| 299 |
+
<tr>
|
| 300 |
+
<td colspan="2" align="center"><strong>FGQA</strong></td>
|
| 301 |
+
</tr>
|
| 302 |
+
<tr>
|
| 303 |
+
<td colspan="2" align="center">
|
| 304 |
+
<img src="https://github.com/user-attachments/assets/4f5c6c5e-687d-49df-9bf8-db9ec7f1f281" alt="fgqa" width="500"/>
|
| 305 |
+
</td>
|
| 306 |
+
</tr>
|
| 307 |
+
<tr>
|
| 308 |
+
<th>Question</th>
|
| 309 |
+
<th>Answer</th>
|
| 310 |
+
</tr>
|
| 311 |
+
<tr>
|
| 312 |
+
<td>In what direction do you move the tool while removing the shell?</td>
|
| 313 |
+
<td>Both clockwise and anticlockwise.</td>
|
| 314 |
+
</tr>
|
| 315 |
+
</table>
|
| 316 |
+
|
| 317 |
+
<table>
|
| 318 |
+
<tr>
|
| 319 |
+
<td colspan="2" align="center"><strong>STC</strong></td>
|
| 320 |
+
</tr>
|
| 321 |
+
<tr>
|
| 322 |
+
<td colspan="2" align="center">
|
| 323 |
+
<img src="https://github.com/user-attachments/assets/a2a129c7-c1e9-47b5-a3b4-fc96a237a9fb" alt="stc" width="500"/>
|
| 324 |
+
</td>
|
| 325 |
+
</tr>
|
| 326 |
+
<tr>
|
| 327 |
+
<th>Time (s) </th>
|
| 328 |
+
<th>Description</th>
|
| 329 |
+
</tr>
|
| 330 |
+
<tr>
|
| 331 |
+
<td>[0, 4]</td>
|
| 332 |
+
<td>The masked subject is a young boy wearing a red jacket and gray pants. He is grasping a monkey bar–like activity in a playground.</td>
|
| 333 |
+
</tr>
|
| 334 |
+
<tr>
|
| 335 |
+
<td>[5, 14]</td>
|
| 336 |
+
<td>He lets go of his hands and runs to the right side of the frame.</td>
|
| 337 |
+
</tr>
|
| 338 |
+
<tr>
|
| 339 |
+
<td>[15, 30]</td>
|
| 340 |
+
<td>The subject is out of frame.</td>
|
| 341 |
+
</tr>
|
| 342 |
+
<tr>
|
| 343 |
+
<td>[31, 45]</td>
|
| 344 |
+
<td>The subject runs back into the frame toward the higher monkey bar in the playground.</td>
|
| 345 |
+
</tr>
|
| 346 |
+
<tr>
|
| 347 |
+
<td>[46, 74]</td>
|
| 348 |
+
<td>He jumps underneath the metal bar and looks up at it. A man wearing a white polo runs toward the subject.</td>
|
| 349 |
+
</tr>
|
| 350 |
+
<tr>
|
| 351 |
+
<td>[75, 116]</td>
|
| 352 |
+
<td>The man in the white polo lifts the subject upward so he can grasp the higher metal bar. The subject holds onto the bar and hangs from it.</td>
|
| 353 |
+
</tr>
|
| 354 |
+
</table>
|
| 355 |
+
|
| 356 |
+
---
|
| 357 |
+
|
| 358 |
+
### 🤖 Auto-Generated Datasets
|
| 359 |
+
|
| 360 |
+
Sythetic image/video captions and QAs used in PLM, please refer to the paper, Section 3 (PLM), for more details. The sythetic annotations covers: SA1B, Openimages, Obejct365, ArxivQA, UCSF, PDFAcc, YT-1B, Ego4d with captions, YT-1B with MCQAs and Ego4d with QAs.
|
| 361 |
+
|
| 362 |
+
🖼️ [**PLM-Image-Auto**](https://huggingface.co/datasets/facebook/PLM-Image-Auto) — Automatically generated image datasets
|
| 363 |
+
|
| 364 |
+
📹 [**PLM-Video-Auto**](https://huggingface.co/datasets/facebook/PLM-Video-Auto) — Automatically generated video datasets
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
---
|
| 368 |
+
|
| 369 |
+
## Installation :wrench:
|
| 370 |
+
```shell
|
| 371 |
+
git clone https://github.com/facebookresearch/perception_models.git
|
| 372 |
+
cd perception_models
|
| 373 |
+
|
| 374 |
+
conda create --name perception_models python=3.12
|
| 375 |
+
conda activate perception_models
|
| 376 |
+
|
| 377 |
+
# Install PyTorch
|
| 378 |
+
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 xformers --index-url https://download.pytorch.org/whl/cu124
|
| 379 |
+
|
| 380 |
+
# We use torchcodec for decoding videos into PyTorch tensors
|
| 381 |
+
conda install ffmpeg -c conda-forge
|
| 382 |
+
pip install torchcodec==0.1 --index-url=https://download.pytorch.org/whl/cu124
|
| 383 |
+
|
| 384 |
+
pip install -e .
|
| 385 |
+
```
|
| 386 |
+
This will install an editable version of repo, allowing you to make changes to the code without needing to reinstall the package every time.
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
## 🙏 Acknowledgement
|
| 390 |
+
We are thankful to [Meta Lingua](https://github.com/facebookresearch/lingua) for releasing their code as open-source contributions. The code structure and code implementation of the LLM is directly forked from [Meta Lingua](https://github.com/facebookresearch/lingua). We are also thankful to [Open_CLIP](https://github.com/mlfoundations/open_clip) for open-source contributions in CLIP training, and [CLIP_benchmark](https://github.com/LAION-AI/CLIP_benchmark) for CLIP model evaluation.
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
## 📜 Citation
|
| 394 |
+
```BibTeX
|
| 395 |
+
@article{bolya2025PerceptionEncoder,
|
| 396 |
+
title={Perception Encoder: The best visual embeddings are not at the output of the network},
|
| 397 |
+
author={Daniel Bolya and Po-Yao Huang and Peize Sun and Jang Hyun Cho and Andrea Madotto and Chen Wei and Tengyu Ma and Jiale Zhi and Jathushan Rajasegaran and Hanoona Rasheed and Junke Wang and Marco Monteiro and Hu Xu and Shiyu Dong and Nikhila Ravi and Daniel Li and Piotr Doll{\'a}r and Christoph Feichtenhofer},
|
| 398 |
+
journal={arXiv:2504.13181},
|
| 399 |
+
year={2025}
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
@article{cho2025PerceptionLM,
|
| 403 |
+
title={PerceptionLM: Open-Access Data and Models for Detailed Visual Understanding},
|
| 404 |
+
author={Jang Hyun Cho and Andrea Madotto and Effrosyni Mavroudi and Triantafyllos Afouras and Tushar Nagarajan and Muhammad Maaz and Yale Song and Tengyu Ma and Shuming Hu and Hanoona Rasheed and Peize Sun and Po-Yao Huang and Daniel Bolya and Suyog Jain and Miguel Martin and Huiyu Wang and Nikhila Ravi and Shashank Jain and Temmy Stark and Shane Moon and Babak Damavandi and Vivian Lee and Andrew Westbury and Salman Khan and Philipp Kr\"{a}henb\"{u}hl and Piotr Doll{\'a}r and Lorenzo Torresani and Kristen Grauman and Christoph Feichtenhofer},
|
| 405 |
+
journal={arXiv:2504.13180},
|
| 406 |
+
year={2025}
|
| 407 |
+
}
|
| 408 |
+
```
|
perception_models/__pycache__/legrad_pe_audio.cpython-310.pyc
ADDED
|
Binary file (6.49 kB). View file
|
|
|
perception_models/__pycache__/legrad_pe_audio.cpython-313.pyc
ADDED
|
Binary file (10.2 kB). View file
|
|
|
perception_models/__pycache__/legrad_pe_image.cpython-312.pyc
ADDED
|
Binary file (12.8 kB). View file
|
|
|
perception_models/__pycache__/legrad_pe_image.cpython-313.pyc
ADDED
|
Binary file (11.5 kB). View file
|
|
|
perception_models/apps/detection/DETA_pe/README.md
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SOTA COCO Object Detection with PE
|
| 2 |
+
|
| 3 |
+
## Getting started
|
| 4 |
+
|
| 5 |
+
Please refer to [INSTALL.md](../INSTALL.md) for installation and dataset preparation instructions.
|
| 6 |
+
|
| 7 |
+
Also install [Deformable Attention](models/ops/make.sh) ops.
|
| 8 |
+
|
| 9 |
+
## Results and Fine-tuned Models
|
| 10 |
+
|
| 11 |
+
<table><tbody>
|
| 12 |
+
<!-- START TABLE -->
|
| 13 |
+
<!-- TABLE HEADER -->
|
| 14 |
+
<th valign="bottom">detector</th>
|
| 15 |
+
<th valign="bottom">vision encoder</th>
|
| 16 |
+
<th valign="bottom">box<br/>AP</th>
|
| 17 |
+
<th valign="bottom">box(TTA)<br/>AP</th>
|
| 18 |
+
<th valign="bottom">download</th>
|
| 19 |
+
<!-- TABLE BODY -->
|
| 20 |
+
<!-- ROW: DETA -->
|
| 21 |
+
<tr><td align="left">DETA</td>
|
| 22 |
+
<td align="center">PE spatial G</td>
|
| 23 |
+
<td align="center"> 65.2 </td>
|
| 24 |
+
<td align="center"> 66.0 </td>
|
| 25 |
+
<td align="center"><a href="https://huggingface.co/facebook/PE-Detection/resolve/main/deta_coco_1824pix.pth">model</a></td>
|
| 26 |
+
</tr>
|
| 27 |
+
</tbody></table>
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
## Training
|
| 31 |
+
We apply a four-stage training, Objects365(12ep, 1024pix), Objects365(6ep, 1536pix), COCO(12ep, 1728pix), COCO(3ep, 1824pix)
|
| 32 |
+
|
| 33 |
+
```
|
| 34 |
+
sbatch scripts/pretrain_spatial_Gwin384_o365ep12_1024pix_16node.sh
|
| 35 |
+
|
| 36 |
+
sbatch scripts/pretrain_continue_spatial_Gwin384_o365ep6_1536pix_16node.sh
|
| 37 |
+
|
| 38 |
+
sbatch scripts/finetune_spatial_Gwin384_cocoep12_1728pix_8node.sh
|
| 39 |
+
|
| 40 |
+
sbatch scripts/finetune_further_spatial_Gwin384_cocoep3_1824pix_8node.sh
|
| 41 |
+
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
## Evaluation
|
| 45 |
+
```
|
| 46 |
+
bash scripts/eval_1824pix.sh --resume deta_coco_1824pix.pth
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
## Evaluation with TTA (Test-Time Augmentation)
|
| 50 |
+
```
|
| 51 |
+
sbatch scripts/eval_tta_slurm_1824pix.sh --resume deta_coco_1824pix.pth
|
| 52 |
+
```
|
| 53 |
+
Note: If you get 65.9 AP, it is probably caused by different package versions, trying different hyperparameters like `--quad_scale 0.4` will give 66.0 AP.
|
perception_models/apps/detection/DETA_pe/datasets/__init__.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
# Modified from DETR (https://github.com/facebookresearch/detr)
|
| 7 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 8 |
+
# ------------------------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
import torch.utils.data
|
| 11 |
+
|
| 12 |
+
from .coco import build as build_coco
|
| 13 |
+
from .objects365 import build as build_objects365
|
| 14 |
+
from .torchvision_datasets import CocoDetection
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_coco_api_from_dataset(dataset):
|
| 18 |
+
for _ in range(10):
|
| 19 |
+
# if isinstance(dataset, torchvision.datasets.CocoDetection):
|
| 20 |
+
# break
|
| 21 |
+
if isinstance(dataset, torch.utils.data.Subset):
|
| 22 |
+
dataset = dataset.dataset
|
| 23 |
+
if isinstance(dataset, CocoDetection):
|
| 24 |
+
return dataset.coco
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def build_dataset(image_set, args):
|
| 28 |
+
if args.dataset_file == "objects365":
|
| 29 |
+
return build_objects365(image_set, args)
|
| 30 |
+
if args.dataset_file == "coco":
|
| 31 |
+
return build_coco(image_set, args)
|
| 32 |
+
if args.dataset_file == "coco_panoptic":
|
| 33 |
+
# to avoid making panopticapi required for coco
|
| 34 |
+
from .coco_panoptic import build as build_coco_panoptic
|
| 35 |
+
|
| 36 |
+
return build_coco_panoptic(image_set, args)
|
| 37 |
+
raise ValueError(f"dataset {args.dataset_file} not supported")
|
perception_models/apps/detection/DETA_pe/datasets/coco.py
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
# Modified from DETR (https://github.com/facebookresearch/detr)
|
| 7 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 8 |
+
# ------------------------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
COCO dataset which returns image_id for evaluation.
|
| 12 |
+
|
| 13 |
+
Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
|
| 14 |
+
"""
|
| 15 |
+
import random
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
import datasets.transforms as T
|
| 19 |
+
import torch
|
| 20 |
+
import torch.utils.data
|
| 21 |
+
import torchvision.transforms.functional as F
|
| 22 |
+
from pycocotools import mask as coco_mask
|
| 23 |
+
from util.misc import get_local_rank, get_local_size
|
| 24 |
+
|
| 25 |
+
from .torchvision_datasets import CocoDetection as TvCocoDetection
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class CocoDetection(TvCocoDetection):
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
img_folder,
|
| 33 |
+
ann_file,
|
| 34 |
+
transforms,
|
| 35 |
+
return_masks,
|
| 36 |
+
cache_mode=False,
|
| 37 |
+
local_rank=0,
|
| 38 |
+
local_size=1,
|
| 39 |
+
test_hflip_aug=False,
|
| 40 |
+
tta=False,
|
| 41 |
+
is_train=False,
|
| 42 |
+
lsj_img_size=1824,
|
| 43 |
+
):
|
| 44 |
+
super(CocoDetection, self).__init__(
|
| 45 |
+
img_folder,
|
| 46 |
+
ann_file,
|
| 47 |
+
cache_mode=cache_mode,
|
| 48 |
+
local_rank=local_rank,
|
| 49 |
+
local_size=local_size,
|
| 50 |
+
)
|
| 51 |
+
self._transforms = transforms
|
| 52 |
+
self.prepare = ConvertCocoPolysToMask(return_masks)
|
| 53 |
+
self.test_hflip_aug = test_hflip_aug
|
| 54 |
+
self.tta = tta
|
| 55 |
+
if lsj_img_size == 1728: # for back-compatibility
|
| 56 |
+
self.tta_image_size = [1536, 1152,]
|
| 57 |
+
else:
|
| 58 |
+
self.tta_image_size = [1728, 1536, 1344,]
|
| 59 |
+
|
| 60 |
+
self.is_train = is_train
|
| 61 |
+
|
| 62 |
+
def __getitem__(self, idx):
|
| 63 |
+
img, target = super(CocoDetection, self).__getitem__(idx)
|
| 64 |
+
image_id = self.ids[idx]
|
| 65 |
+
target = {"image_id": image_id, "annotations": target}
|
| 66 |
+
img, target = self.prepare(img, target)
|
| 67 |
+
if self._transforms is not None:
|
| 68 |
+
img, target = self._transforms(img, target)
|
| 69 |
+
|
| 70 |
+
if self.test_hflip_aug:
|
| 71 |
+
flipped_img = torch.flip(img, dims=[-1])
|
| 72 |
+
new_img = torch.cat([img, flipped_img], dim=0)
|
| 73 |
+
return new_img, target
|
| 74 |
+
|
| 75 |
+
elif self.tta:
|
| 76 |
+
tta_images = [img]
|
| 77 |
+
flipped_img = torch.flip(img, dims=[-1])
|
| 78 |
+
tta_images.append(flipped_img)
|
| 79 |
+
_, height, width = img.shape
|
| 80 |
+
max_size_len = height if height >= width else width
|
| 81 |
+
for new_max_size in self.tta_image_size:
|
| 82 |
+
scale = new_max_size / max_size_len
|
| 83 |
+
new_height, new_width = int(scale * height), int(scale * width)
|
| 84 |
+
new_img = F.resize(img, size=(new_height, new_width))
|
| 85 |
+
tta_images.append(new_img)
|
| 86 |
+
flipped_img = torch.flip(new_img, dims=[-1])
|
| 87 |
+
tta_images.append(flipped_img)
|
| 88 |
+
return tta_images, target
|
| 89 |
+
else:
|
| 90 |
+
return img, target
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def convert_coco_poly_to_mask(segmentations, height, width):
|
| 94 |
+
masks = []
|
| 95 |
+
for polygons in segmentations:
|
| 96 |
+
rles = coco_mask.frPyObjects(polygons, height, width)
|
| 97 |
+
mask = coco_mask.decode(rles)
|
| 98 |
+
if len(mask.shape) < 3:
|
| 99 |
+
mask = mask[..., None]
|
| 100 |
+
mask = torch.as_tensor(mask, dtype=torch.uint8)
|
| 101 |
+
mask = mask.any(dim=2)
|
| 102 |
+
masks.append(mask)
|
| 103 |
+
if masks:
|
| 104 |
+
masks = torch.stack(masks, dim=0)
|
| 105 |
+
else:
|
| 106 |
+
masks = torch.zeros((0, height, width), dtype=torch.uint8)
|
| 107 |
+
return masks
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class ConvertCocoPolysToMask(object):
|
| 111 |
+
def __init__(self, return_masks=False):
|
| 112 |
+
self.return_masks = return_masks
|
| 113 |
+
|
| 114 |
+
def __call__(self, image, target):
|
| 115 |
+
w, h = image.size
|
| 116 |
+
|
| 117 |
+
image_id = target["image_id"]
|
| 118 |
+
image_id = torch.tensor([image_id])
|
| 119 |
+
|
| 120 |
+
anno = target["annotations"]
|
| 121 |
+
|
| 122 |
+
anno = [obj for obj in anno if "iscrowd" not in obj or obj["iscrowd"] == 0]
|
| 123 |
+
|
| 124 |
+
boxes = [obj["bbox"] for obj in anno]
|
| 125 |
+
# guard against no boxes via resizing
|
| 126 |
+
boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
|
| 127 |
+
boxes[:, 2:] += boxes[:, :2]
|
| 128 |
+
boxes[:, 0::2].clamp_(min=0, max=w)
|
| 129 |
+
boxes[:, 1::2].clamp_(min=0, max=h)
|
| 130 |
+
|
| 131 |
+
classes = [obj["category_id"] for obj in anno]
|
| 132 |
+
classes = torch.tensor(classes, dtype=torch.int64)
|
| 133 |
+
|
| 134 |
+
if self.return_masks:
|
| 135 |
+
segmentations = [obj["segmentation"] for obj in anno]
|
| 136 |
+
masks = convert_coco_poly_to_mask(segmentations, h, w)
|
| 137 |
+
|
| 138 |
+
keypoints = None
|
| 139 |
+
if anno and "keypoints" in anno[0]:
|
| 140 |
+
keypoints = [obj["keypoints"] for obj in anno]
|
| 141 |
+
keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
|
| 142 |
+
num_keypoints = keypoints.shape[0]
|
| 143 |
+
if num_keypoints:
|
| 144 |
+
keypoints = keypoints.view(num_keypoints, -1, 3)
|
| 145 |
+
|
| 146 |
+
keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
|
| 147 |
+
boxes = boxes[keep]
|
| 148 |
+
classes = classes[keep]
|
| 149 |
+
if self.return_masks:
|
| 150 |
+
masks = masks[keep]
|
| 151 |
+
if keypoints is not None:
|
| 152 |
+
keypoints = keypoints[keep]
|
| 153 |
+
|
| 154 |
+
target = {}
|
| 155 |
+
target["boxes"] = boxes
|
| 156 |
+
target["labels"] = classes
|
| 157 |
+
if self.return_masks:
|
| 158 |
+
target["masks"] = masks
|
| 159 |
+
target["image_id"] = image_id
|
| 160 |
+
if keypoints is not None:
|
| 161 |
+
target["keypoints"] = keypoints
|
| 162 |
+
|
| 163 |
+
# for conversion to coco api
|
| 164 |
+
area = torch.tensor([obj["area"] for obj in anno])
|
| 165 |
+
iscrowd = torch.tensor(
|
| 166 |
+
[obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]
|
| 167 |
+
)
|
| 168 |
+
target["area"] = area[keep]
|
| 169 |
+
target["iscrowd"] = iscrowd[keep]
|
| 170 |
+
|
| 171 |
+
target["orig_size"] = torch.as_tensor([int(h), int(w)])
|
| 172 |
+
target["size"] = torch.as_tensor([int(h), int(w)])
|
| 173 |
+
|
| 174 |
+
return image, target
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def make_coco_transforms(image_set, bigger):
|
| 178 |
+
|
| 179 |
+
normalize = T.Compose(
|
| 180 |
+
[T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
if "train" in image_set:
|
| 184 |
+
scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
|
| 185 |
+
if "val" in image_set or "test" in image_set:
|
| 186 |
+
scales = [800]
|
| 187 |
+
|
| 188 |
+
max_size = 1333
|
| 189 |
+
if bigger:
|
| 190 |
+
scales = [int(1.5 * s) for s in scales]
|
| 191 |
+
max_size = 2000
|
| 192 |
+
|
| 193 |
+
if image_set == "train":
|
| 194 |
+
augmentation_list = [
|
| 195 |
+
T.RandomHorizontalFlip(),
|
| 196 |
+
T.RandomSelect(
|
| 197 |
+
T.RandomResize(scales, max_size=max_size),
|
| 198 |
+
T.Compose(
|
| 199 |
+
[
|
| 200 |
+
T.RandomResize([400, 500, 600]),
|
| 201 |
+
T.RandomSizeCrop(384, 600),
|
| 202 |
+
T.RandomResize(scales, max_size=max_size),
|
| 203 |
+
]
|
| 204 |
+
),
|
| 205 |
+
),
|
| 206 |
+
normalize,
|
| 207 |
+
]
|
| 208 |
+
|
| 209 |
+
return T.Compose(augmentation_list)
|
| 210 |
+
|
| 211 |
+
if image_set == "val":
|
| 212 |
+
return T.Compose(
|
| 213 |
+
[
|
| 214 |
+
T.RandomResize(scales, max_size=max_size),
|
| 215 |
+
normalize,
|
| 216 |
+
]
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
raise ValueError(f"unknown {image_set}")
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def make_coco_transforms_lsj(
|
| 223 |
+
image_set, image_size, lsj_img_train_min=480, lsj_strong_aug=False
|
| 224 |
+
):
|
| 225 |
+
"""
|
| 226 |
+
Reference: https://github.com/facebookresearch/detectron2/blob/main/projects/ViTDet/configs/common/coco_loader_lsj.py
|
| 227 |
+
|
| 228 |
+
import detectron2.data.transforms as T
|
| 229 |
+
from detectron2 import model_zoo
|
| 230 |
+
from detectron2.config import LazyCall as L
|
| 231 |
+
|
| 232 |
+
# Data using LSJ
|
| 233 |
+
image_size = 1024
|
| 234 |
+
dataloader = model_zoo.get_config("common/data/coco.py").dataloader
|
| 235 |
+
dataloader.train.mapper.augmentations = [
|
| 236 |
+
L(T.RandomFlip)(horizontal=True), # flip first
|
| 237 |
+
L(T.ResizeScale)(
|
| 238 |
+
min_scale=0.1, max_scale=2.0, target_height=image_size, target_width=image_size
|
| 239 |
+
),
|
| 240 |
+
L(T.FixedSizeCrop)(crop_size=(image_size, image_size), pad=False),
|
| 241 |
+
]
|
| 242 |
+
dataloader.train.mapper.image_format = "RGB"
|
| 243 |
+
dataloader.train.total_batch_size = 64
|
| 244 |
+
# recompute boxes due to cropping
|
| 245 |
+
dataloader.train.mapper.recompute_boxes = True
|
| 246 |
+
|
| 247 |
+
dataloader.test.mapper.augmentations = [
|
| 248 |
+
L(T.ResizeShortestEdge)(short_edge_length=image_size, max_size=image_size),
|
| 249 |
+
]
|
| 250 |
+
"""
|
| 251 |
+
|
| 252 |
+
"""
|
| 253 |
+
In our implementation, we simulate lsj data augmentation by:
|
| 254 |
+
(1) first the following augmentations
|
| 255 |
+
(2) then padding to (image_size, image_size) in collator, see util/misc/collate_fn_lsj.py
|
| 256 |
+
"""
|
| 257 |
+
normalize = T.Compose(
|
| 258 |
+
[T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
if "train" in image_set:
|
| 262 |
+
scales = [scale for scale in range(lsj_img_train_min, image_size, 32)]
|
| 263 |
+
if "val" in image_set or "test" in image_set or "unlabel" in image_set:
|
| 264 |
+
scales = [image_size - 32]
|
| 265 |
+
|
| 266 |
+
# max_size = 1333
|
| 267 |
+
# if bigger:
|
| 268 |
+
# scales = [int(1.5 * s) for s in scales]
|
| 269 |
+
# max_size = 2000
|
| 270 |
+
max_size = image_size - 32 # for some wired bugs
|
| 271 |
+
|
| 272 |
+
augmentation_list = []
|
| 273 |
+
if "train" in image_set:
|
| 274 |
+
if lsj_strong_aug:
|
| 275 |
+
augmentation_list.extend(
|
| 276 |
+
[
|
| 277 |
+
T.ColorJitter((0.4, 0.4, 0.4, 0.1), p=0.5),
|
| 278 |
+
T.RandomGrayscale(p=0.2),
|
| 279 |
+
# T.RandomErasingP05(),
|
| 280 |
+
]
|
| 281 |
+
)
|
| 282 |
+
augmentation_list.extend(
|
| 283 |
+
[
|
| 284 |
+
T.RandomHorizontalFlip(),
|
| 285 |
+
T.RandomSelect(
|
| 286 |
+
# similar to (T.ResizeScale)(min_scale=0.1, max_scale=1.0, target_height=image_size, target_width=image_size) and pad
|
| 287 |
+
T.RandomResize(scales, max_size=max_size),
|
| 288 |
+
# similar to (T.ResizeScale)(min_scale=1.0, max_scale=2.0, target_height=image_size, target_width=image_size) and crop
|
| 289 |
+
T.Compose(
|
| 290 |
+
[
|
| 291 |
+
T.RandomResize([400, 500, 600]),
|
| 292 |
+
T.RandomSizeCrop(384, 600),
|
| 293 |
+
T.RandomResize([max_size], max_size=max_size),
|
| 294 |
+
]
|
| 295 |
+
),
|
| 296 |
+
),
|
| 297 |
+
normalize,
|
| 298 |
+
]
|
| 299 |
+
)
|
| 300 |
+
return T.Compose(augmentation_list)
|
| 301 |
+
|
| 302 |
+
if image_set == "val":
|
| 303 |
+
return T.Compose(
|
| 304 |
+
[
|
| 305 |
+
T.RandomResize(scales, max_size=max_size),
|
| 306 |
+
normalize,
|
| 307 |
+
]
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
raise ValueError(f"unknown {image_set}")
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def build(image_set, args):
|
| 314 |
+
root = Path(args.coco_path)
|
| 315 |
+
assert root.exists(), f"provided COCO path {root} does not exist"
|
| 316 |
+
mode = "instances"
|
| 317 |
+
PATHS = {
|
| 318 |
+
"train": (root / "train2017", root / "annotations" / f"{mode}_train2017.json"),
|
| 319 |
+
"val": (root / "val2017", root / "annotations" / f"{mode}_val2017.json"),
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
img_folder, ann_file = PATHS[image_set]
|
| 323 |
+
if args.lsj:
|
| 324 |
+
coco_transform = make_coco_transforms_lsj(
|
| 325 |
+
image_set,
|
| 326 |
+
args.lsj_img_size,
|
| 327 |
+
args.lsj_img_train_min,
|
| 328 |
+
args.lsj_strong_aug,
|
| 329 |
+
)
|
| 330 |
+
else:
|
| 331 |
+
coco_transform = make_coco_transforms(image_set, args.bigger)
|
| 332 |
+
dataset = CocoDetection(
|
| 333 |
+
img_folder,
|
| 334 |
+
ann_file,
|
| 335 |
+
transforms=coco_transform,
|
| 336 |
+
return_masks=args.masks,
|
| 337 |
+
cache_mode=args.cache_mode,
|
| 338 |
+
local_rank=get_local_rank(),
|
| 339 |
+
local_size=get_local_size(),
|
| 340 |
+
test_hflip_aug=args.test_hflip_aug,
|
| 341 |
+
tta=args.tta,
|
| 342 |
+
is_train=("train" in image_set),
|
| 343 |
+
lsj_img_size=args.lsj_img_size,
|
| 344 |
+
)
|
| 345 |
+
return dataset
|
perception_models/apps/detection/DETA_pe/datasets/coco_eval.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
# Modified from DETR (https://github.com/facebookresearch/detr)
|
| 7 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 8 |
+
# ------------------------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
COCO evaluator that works in distributed mode.
|
| 12 |
+
|
| 13 |
+
Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py
|
| 14 |
+
The difference is that there is less copy-pasting from pycocotools
|
| 15 |
+
in the end of the file, as python3 can suppress prints with contextlib
|
| 16 |
+
"""
|
| 17 |
+
import os
|
| 18 |
+
import contextlib
|
| 19 |
+
import copy
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
from pycocotools.cocoeval import COCOeval
|
| 24 |
+
from pycocotools.coco import COCO
|
| 25 |
+
import pycocotools.mask as mask_util
|
| 26 |
+
|
| 27 |
+
from util.misc import all_gather
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class CocoEvaluator(object):
|
| 31 |
+
def __init__(self, coco_gt, iou_types):
|
| 32 |
+
assert isinstance(iou_types, (list, tuple))
|
| 33 |
+
coco_gt = copy.deepcopy(coco_gt)
|
| 34 |
+
self.coco_gt = coco_gt
|
| 35 |
+
|
| 36 |
+
self.iou_types = iou_types
|
| 37 |
+
self.coco_eval = {}
|
| 38 |
+
for iou_type in iou_types:
|
| 39 |
+
self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
|
| 40 |
+
|
| 41 |
+
self.img_ids = []
|
| 42 |
+
self.eval_imgs = {k: [] for k in iou_types}
|
| 43 |
+
|
| 44 |
+
def update(self, predictions):
|
| 45 |
+
img_ids = list(np.unique(list(predictions.keys())))
|
| 46 |
+
self.img_ids.extend(img_ids)
|
| 47 |
+
|
| 48 |
+
for iou_type in self.iou_types:
|
| 49 |
+
results = self.prepare(predictions, iou_type)
|
| 50 |
+
|
| 51 |
+
# suppress pycocotools prints
|
| 52 |
+
with open(os.devnull, 'w') as devnull:
|
| 53 |
+
with contextlib.redirect_stdout(devnull):
|
| 54 |
+
coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
|
| 55 |
+
coco_eval = self.coco_eval[iou_type]
|
| 56 |
+
|
| 57 |
+
coco_eval.cocoDt = coco_dt
|
| 58 |
+
coco_eval.params.imgIds = list(img_ids)
|
| 59 |
+
img_ids, eval_imgs = evaluate(coco_eval)
|
| 60 |
+
|
| 61 |
+
self.eval_imgs[iou_type].append(eval_imgs)
|
| 62 |
+
|
| 63 |
+
def synchronize_between_processes(self):
|
| 64 |
+
for iou_type in self.iou_types:
|
| 65 |
+
self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
|
| 66 |
+
create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
|
| 67 |
+
|
| 68 |
+
def accumulate(self):
|
| 69 |
+
for coco_eval in self.coco_eval.values():
|
| 70 |
+
coco_eval.accumulate()
|
| 71 |
+
|
| 72 |
+
def summarize(self):
|
| 73 |
+
for iou_type, coco_eval in self.coco_eval.items():
|
| 74 |
+
print("IoU metric: {}".format(iou_type))
|
| 75 |
+
coco_eval.summarize()
|
| 76 |
+
|
| 77 |
+
def prepare(self, predictions, iou_type):
|
| 78 |
+
if iou_type == "bbox":
|
| 79 |
+
return self.prepare_for_coco_detection(predictions)
|
| 80 |
+
elif iou_type == "segm":
|
| 81 |
+
return self.prepare_for_coco_segmentation(predictions)
|
| 82 |
+
elif iou_type == "keypoints":
|
| 83 |
+
return self.prepare_for_coco_keypoint(predictions)
|
| 84 |
+
else:
|
| 85 |
+
raise ValueError("Unknown iou type {}".format(iou_type))
|
| 86 |
+
|
| 87 |
+
def prepare_for_coco_detection(self, predictions):
|
| 88 |
+
coco_results = []
|
| 89 |
+
for original_id, prediction in predictions.items():
|
| 90 |
+
if len(prediction) == 0:
|
| 91 |
+
continue
|
| 92 |
+
|
| 93 |
+
boxes = prediction["boxes"]
|
| 94 |
+
boxes = convert_to_xywh(boxes).tolist()
|
| 95 |
+
scores = prediction["scores"].tolist()
|
| 96 |
+
labels = prediction["labels"].tolist()
|
| 97 |
+
|
| 98 |
+
coco_results.extend(
|
| 99 |
+
[
|
| 100 |
+
{
|
| 101 |
+
"image_id": original_id,
|
| 102 |
+
"category_id": labels[k],
|
| 103 |
+
"bbox": box,
|
| 104 |
+
"score": scores[k],
|
| 105 |
+
}
|
| 106 |
+
for k, box in enumerate(boxes)
|
| 107 |
+
]
|
| 108 |
+
)
|
| 109 |
+
return coco_results
|
| 110 |
+
|
| 111 |
+
def prepare_for_coco_segmentation(self, predictions):
|
| 112 |
+
coco_results = []
|
| 113 |
+
for original_id, prediction in predictions.items():
|
| 114 |
+
if len(prediction) == 0:
|
| 115 |
+
continue
|
| 116 |
+
|
| 117 |
+
scores = prediction["scores"]
|
| 118 |
+
labels = prediction["labels"]
|
| 119 |
+
masks = prediction["masks"]
|
| 120 |
+
|
| 121 |
+
masks = masks > 0.5
|
| 122 |
+
|
| 123 |
+
scores = prediction["scores"].tolist()
|
| 124 |
+
labels = prediction["labels"].tolist()
|
| 125 |
+
|
| 126 |
+
rles = [
|
| 127 |
+
mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
|
| 128 |
+
for mask in masks
|
| 129 |
+
]
|
| 130 |
+
for rle in rles:
|
| 131 |
+
rle["counts"] = rle["counts"].decode("utf-8")
|
| 132 |
+
|
| 133 |
+
coco_results.extend(
|
| 134 |
+
[
|
| 135 |
+
{
|
| 136 |
+
"image_id": original_id,
|
| 137 |
+
"category_id": labels[k],
|
| 138 |
+
"segmentation": rle,
|
| 139 |
+
"score": scores[k],
|
| 140 |
+
}
|
| 141 |
+
for k, rle in enumerate(rles)
|
| 142 |
+
]
|
| 143 |
+
)
|
| 144 |
+
return coco_results
|
| 145 |
+
|
| 146 |
+
def prepare_for_coco_keypoint(self, predictions):
|
| 147 |
+
coco_results = []
|
| 148 |
+
for original_id, prediction in predictions.items():
|
| 149 |
+
if len(prediction) == 0:
|
| 150 |
+
continue
|
| 151 |
+
|
| 152 |
+
boxes = prediction["boxes"]
|
| 153 |
+
boxes = convert_to_xywh(boxes).tolist()
|
| 154 |
+
scores = prediction["scores"].tolist()
|
| 155 |
+
labels = prediction["labels"].tolist()
|
| 156 |
+
keypoints = prediction["keypoints"]
|
| 157 |
+
keypoints = keypoints.flatten(start_dim=1).tolist()
|
| 158 |
+
|
| 159 |
+
coco_results.extend(
|
| 160 |
+
[
|
| 161 |
+
{
|
| 162 |
+
"image_id": original_id,
|
| 163 |
+
"category_id": labels[k],
|
| 164 |
+
'keypoints': keypoint,
|
| 165 |
+
"score": scores[k],
|
| 166 |
+
}
|
| 167 |
+
for k, keypoint in enumerate(keypoints)
|
| 168 |
+
]
|
| 169 |
+
)
|
| 170 |
+
return coco_results
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def convert_to_xywh(boxes):
|
| 174 |
+
xmin, ymin, xmax, ymax = boxes.unbind(1)
|
| 175 |
+
return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def merge(img_ids, eval_imgs):
|
| 179 |
+
all_img_ids = all_gather(img_ids)
|
| 180 |
+
all_eval_imgs = all_gather(eval_imgs)
|
| 181 |
+
|
| 182 |
+
merged_img_ids = []
|
| 183 |
+
for p in all_img_ids:
|
| 184 |
+
merged_img_ids.extend(p)
|
| 185 |
+
|
| 186 |
+
merged_eval_imgs = []
|
| 187 |
+
for p in all_eval_imgs:
|
| 188 |
+
merged_eval_imgs.append(p)
|
| 189 |
+
|
| 190 |
+
merged_img_ids = np.array(merged_img_ids)
|
| 191 |
+
merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)
|
| 192 |
+
|
| 193 |
+
# keep only unique (and in sorted order) images
|
| 194 |
+
merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
|
| 195 |
+
merged_eval_imgs = merged_eval_imgs[..., idx]
|
| 196 |
+
|
| 197 |
+
return merged_img_ids, merged_eval_imgs
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
|
| 201 |
+
img_ids, eval_imgs = merge(img_ids, eval_imgs)
|
| 202 |
+
img_ids = list(img_ids)
|
| 203 |
+
eval_imgs = list(eval_imgs.flatten())
|
| 204 |
+
|
| 205 |
+
coco_eval.evalImgs = eval_imgs
|
| 206 |
+
coco_eval.params.imgIds = img_ids
|
| 207 |
+
coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
#################################################################
|
| 211 |
+
# From pycocotools, just removed the prints and fixed
|
| 212 |
+
# a Python3 bug about unicode not defined
|
| 213 |
+
#################################################################
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def evaluate(self):
|
| 217 |
+
'''
|
| 218 |
+
Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
|
| 219 |
+
:return: None
|
| 220 |
+
'''
|
| 221 |
+
# tic = time.time()
|
| 222 |
+
# print('Running per image evaluation...')
|
| 223 |
+
p = self.params
|
| 224 |
+
# add backward compatibility if useSegm is specified in params
|
| 225 |
+
if p.useSegm is not None:
|
| 226 |
+
p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
|
| 227 |
+
print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
|
| 228 |
+
# print('Evaluate annotation type *{}*'.format(p.iouType))
|
| 229 |
+
p.imgIds = list(np.unique(p.imgIds))
|
| 230 |
+
if p.useCats:
|
| 231 |
+
p.catIds = list(np.unique(p.catIds))
|
| 232 |
+
p.maxDets = sorted(p.maxDets)
|
| 233 |
+
self.params = p
|
| 234 |
+
|
| 235 |
+
self._prepare()
|
| 236 |
+
# loop through images, area range, max detection number
|
| 237 |
+
catIds = p.catIds if p.useCats else [-1]
|
| 238 |
+
|
| 239 |
+
if p.iouType == 'segm' or p.iouType == 'bbox':
|
| 240 |
+
computeIoU = self.computeIoU
|
| 241 |
+
elif p.iouType == 'keypoints':
|
| 242 |
+
computeIoU = self.computeOks
|
| 243 |
+
self.ious = {
|
| 244 |
+
(imgId, catId): computeIoU(imgId, catId)
|
| 245 |
+
for imgId in p.imgIds
|
| 246 |
+
for catId in catIds}
|
| 247 |
+
|
| 248 |
+
evaluateImg = self.evaluateImg
|
| 249 |
+
maxDet = p.maxDets[-1]
|
| 250 |
+
evalImgs = [
|
| 251 |
+
evaluateImg(imgId, catId, areaRng, maxDet)
|
| 252 |
+
for catId in catIds
|
| 253 |
+
for areaRng in p.areaRng
|
| 254 |
+
for imgId in p.imgIds
|
| 255 |
+
]
|
| 256 |
+
# this is NOT in the pycocotools code, but could be done outside
|
| 257 |
+
evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
|
| 258 |
+
self._paramsEval = copy.deepcopy(self.params)
|
| 259 |
+
# toc = time.time()
|
| 260 |
+
# print('DONE (t={:0.2f}s).'.format(toc-tic))
|
| 261 |
+
return p.imgIds, evalImgs
|
| 262 |
+
|
| 263 |
+
#################################################################
|
| 264 |
+
# end of straight copy from pycocotools, just removing the prints
|
| 265 |
+
#################################################################
|
perception_models/apps/detection/DETA_pe/datasets/coco_panoptic.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
# Modified from DETR (https://github.com/facebookresearch/detr)
|
| 7 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 8 |
+
# ------------------------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
import json
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
from PIL import Image
|
| 16 |
+
|
| 17 |
+
from panopticapi.utils import rgb2id
|
| 18 |
+
from util.box_ops import masks_to_boxes
|
| 19 |
+
|
| 20 |
+
from .coco import make_coco_transforms
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CocoPanoptic:
|
| 24 |
+
def __init__(self, img_folder, ann_folder, ann_file, transforms=None, return_masks=True):
|
| 25 |
+
with open(ann_file, 'r') as f:
|
| 26 |
+
self.coco = json.load(f)
|
| 27 |
+
|
| 28 |
+
# sort 'images' field so that they are aligned with 'annotations'
|
| 29 |
+
# i.e., in alphabetical order
|
| 30 |
+
self.coco['images'] = sorted(self.coco['images'], key=lambda x: x['id'])
|
| 31 |
+
# sanity check
|
| 32 |
+
if "annotations" in self.coco:
|
| 33 |
+
for img, ann in zip(self.coco['images'], self.coco['annotations']):
|
| 34 |
+
assert img['file_name'][:-4] == ann['file_name'][:-4]
|
| 35 |
+
|
| 36 |
+
self.img_folder = img_folder
|
| 37 |
+
self.ann_folder = ann_folder
|
| 38 |
+
self.ann_file = ann_file
|
| 39 |
+
self.transforms = transforms
|
| 40 |
+
self.return_masks = return_masks
|
| 41 |
+
|
| 42 |
+
def __getitem__(self, idx):
|
| 43 |
+
ann_info = self.coco['annotations'][idx] if "annotations" in self.coco else self.coco['images'][idx]
|
| 44 |
+
img_path = Path(self.img_folder) / ann_info['file_name'].replace('.png', '.jpg')
|
| 45 |
+
ann_path = Path(self.ann_folder) / ann_info['file_name']
|
| 46 |
+
|
| 47 |
+
img = Image.open(img_path).convert('RGB')
|
| 48 |
+
w, h = img.size
|
| 49 |
+
if "segments_info" in ann_info:
|
| 50 |
+
masks = np.asarray(Image.open(ann_path), dtype=np.uint32)
|
| 51 |
+
masks = rgb2id(masks)
|
| 52 |
+
|
| 53 |
+
ids = np.array([ann['id'] for ann in ann_info['segments_info']])
|
| 54 |
+
masks = masks == ids[:, None, None]
|
| 55 |
+
|
| 56 |
+
masks = torch.as_tensor(masks, dtype=torch.uint8)
|
| 57 |
+
labels = torch.tensor([ann['category_id'] for ann in ann_info['segments_info']], dtype=torch.int64)
|
| 58 |
+
|
| 59 |
+
target = {}
|
| 60 |
+
target['image_id'] = torch.tensor([ann_info['image_id'] if "image_id" in ann_info else ann_info["id"]])
|
| 61 |
+
if self.return_masks:
|
| 62 |
+
target['masks'] = masks
|
| 63 |
+
target['labels'] = labels
|
| 64 |
+
|
| 65 |
+
target["boxes"] = masks_to_boxes(masks)
|
| 66 |
+
|
| 67 |
+
target['size'] = torch.as_tensor([int(h), int(w)])
|
| 68 |
+
target['orig_size'] = torch.as_tensor([int(h), int(w)])
|
| 69 |
+
if "segments_info" in ann_info:
|
| 70 |
+
for name in ['iscrowd', 'area']:
|
| 71 |
+
target[name] = torch.tensor([ann[name] for ann in ann_info['segments_info']])
|
| 72 |
+
|
| 73 |
+
if self.transforms is not None:
|
| 74 |
+
img, target = self.transforms(img, target)
|
| 75 |
+
|
| 76 |
+
return img, target
|
| 77 |
+
|
| 78 |
+
def __len__(self):
|
| 79 |
+
return len(self.coco['images'])
|
| 80 |
+
|
| 81 |
+
def get_height_and_width(self, idx):
|
| 82 |
+
img_info = self.coco['images'][idx]
|
| 83 |
+
height = img_info['height']
|
| 84 |
+
width = img_info['width']
|
| 85 |
+
return height, width
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def build(image_set, args):
|
| 89 |
+
img_folder_root = Path(args.coco_path)
|
| 90 |
+
ann_folder_root = Path(args.coco_panoptic_path)
|
| 91 |
+
assert img_folder_root.exists(), f'provided COCO path {img_folder_root} does not exist'
|
| 92 |
+
assert ann_folder_root.exists(), f'provided COCO path {ann_folder_root} does not exist'
|
| 93 |
+
mode = 'panoptic'
|
| 94 |
+
PATHS = {
|
| 95 |
+
"train": ("train2017", Path("annotations") / f'{mode}_train2017.json'),
|
| 96 |
+
"val": ("val2017", Path("annotations") / f'{mode}_val2017.json'),
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
img_folder, ann_file = PATHS[image_set]
|
| 100 |
+
img_folder_path = img_folder_root / img_folder
|
| 101 |
+
ann_folder = ann_folder_root / f'{mode}_{img_folder}'
|
| 102 |
+
ann_file = ann_folder_root / ann_file
|
| 103 |
+
|
| 104 |
+
dataset = CocoPanoptic(img_folder_path, ann_folder, ann_file,
|
| 105 |
+
transforms=make_coco_transforms(image_set), return_masks=args.masks)
|
| 106 |
+
|
| 107 |
+
return dataset
|
perception_models/apps/detection/DETA_pe/datasets/data_prefetcher.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
def to_cuda(samples, targets, device):
|
| 10 |
+
samples = samples.to(device, non_blocking=True)
|
| 11 |
+
targets = [{k: v.to(device, non_blocking=True) for k, v in t.items()} for t in targets]
|
| 12 |
+
return samples, targets
|
| 13 |
+
|
| 14 |
+
class data_prefetcher():
|
| 15 |
+
def __init__(self, loader, device, prefetch=True):
|
| 16 |
+
self.loader = iter(loader)
|
| 17 |
+
self.prefetch = prefetch
|
| 18 |
+
self.device = device
|
| 19 |
+
if prefetch:
|
| 20 |
+
self.stream = torch.cuda.Stream()
|
| 21 |
+
self.preload()
|
| 22 |
+
|
| 23 |
+
def preload(self):
|
| 24 |
+
try:
|
| 25 |
+
self.next_samples, self.next_targets = next(self.loader)
|
| 26 |
+
except StopIteration:
|
| 27 |
+
self.next_samples = None
|
| 28 |
+
self.next_targets = None
|
| 29 |
+
return
|
| 30 |
+
# if record_stream() doesn't work, another option is to make sure device inputs are created
|
| 31 |
+
# on the main stream.
|
| 32 |
+
# self.next_input_gpu = torch.empty_like(self.next_input, device='cuda')
|
| 33 |
+
# self.next_target_gpu = torch.empty_like(self.next_target, device='cuda')
|
| 34 |
+
# Need to make sure the memory allocated for next_* is not still in use by the main stream
|
| 35 |
+
# at the time we start copying to next_*:
|
| 36 |
+
# self.stream.wait_stream(torch.cuda.current_stream())
|
| 37 |
+
with torch.cuda.stream(self.stream):
|
| 38 |
+
self.next_samples, self.next_targets = to_cuda(self.next_samples, self.next_targets, self.device)
|
| 39 |
+
# more code for the alternative if record_stream() doesn't work:
|
| 40 |
+
# copy_ will record the use of the pinned source tensor in this side stream.
|
| 41 |
+
# self.next_input_gpu.copy_(self.next_input, non_blocking=True)
|
| 42 |
+
# self.next_target_gpu.copy_(self.next_target, non_blocking=True)
|
| 43 |
+
# self.next_input = self.next_input_gpu
|
| 44 |
+
# self.next_target = self.next_target_gpu
|
| 45 |
+
|
| 46 |
+
# With Amp, it isn't necessary to manually convert data to half.
|
| 47 |
+
# if args.fp16:
|
| 48 |
+
# self.next_input = self.next_input.half()
|
| 49 |
+
# else:
|
| 50 |
+
|
| 51 |
+
def next(self):
|
| 52 |
+
if self.prefetch:
|
| 53 |
+
torch.cuda.current_stream().wait_stream(self.stream)
|
| 54 |
+
samples = self.next_samples
|
| 55 |
+
targets = self.next_targets
|
| 56 |
+
if samples is not None:
|
| 57 |
+
samples.record_stream(torch.cuda.current_stream())
|
| 58 |
+
if targets is not None:
|
| 59 |
+
for t in targets:
|
| 60 |
+
for k, v in t.items():
|
| 61 |
+
v.record_stream(torch.cuda.current_stream())
|
| 62 |
+
self.preload()
|
| 63 |
+
else:
|
| 64 |
+
try:
|
| 65 |
+
samples, targets = next(self.loader)
|
| 66 |
+
samples, targets = to_cuda(samples, targets, self.device)
|
| 67 |
+
except StopIteration:
|
| 68 |
+
samples = None
|
| 69 |
+
targets = None
|
| 70 |
+
return samples, targets
|
perception_models/apps/detection/DETA_pe/datasets/objects365.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
# Modified from DETR (https://github.com/facebookresearch/detr)
|
| 7 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 8 |
+
# ------------------------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
COCO dataset which returns image_id for evaluation.
|
| 12 |
+
|
| 13 |
+
Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
|
| 14 |
+
"""
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
import datasets.transforms as T
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.utils.data
|
| 21 |
+
from pycocotools import mask as coco_mask
|
| 22 |
+
from util.misc import get_local_rank, get_local_size
|
| 23 |
+
|
| 24 |
+
from .coco import CocoDetection, make_coco_transforms, make_coco_transforms_lsj
|
| 25 |
+
from .torchvision_datasets import CocoDetection as TvCocoDetection
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def build(image_set, args):
|
| 29 |
+
root = Path(args.coco_path)
|
| 30 |
+
assert root.exists(), f"provided Objects365 path {root} does not exist"
|
| 31 |
+
mode = "instances"
|
| 32 |
+
PATHS = {
|
| 33 |
+
"train": (
|
| 34 |
+
root / "train",
|
| 35 |
+
root / "annotations" / "zhiyuan_objv2_train_fixmiss.json",
|
| 36 |
+
),
|
| 37 |
+
"val": (root / "val", root / "annotations" / "zhiyuan_objv2_val.json"),
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
img_folder, ann_file = PATHS[image_set]
|
| 41 |
+
if args.lsj:
|
| 42 |
+
coco_transform = make_coco_transforms_lsj(image_set, args.lsj_img_size)
|
| 43 |
+
else:
|
| 44 |
+
coco_transform = make_coco_transforms(image_set, args.bigger)
|
| 45 |
+
dataset = CocoDetection(
|
| 46 |
+
img_folder,
|
| 47 |
+
ann_file,
|
| 48 |
+
transforms=coco_transform,
|
| 49 |
+
return_masks=args.masks,
|
| 50 |
+
cache_mode=args.cache_mode,
|
| 51 |
+
local_rank=get_local_rank(),
|
| 52 |
+
local_size=get_local_size(),
|
| 53 |
+
)
|
| 54 |
+
return dataset
|
perception_models/apps/detection/DETA_pe/datasets/panoptic_eval.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
# Modified from DETR (https://github.com/facebookresearch/detr)
|
| 7 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 8 |
+
# ------------------------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
import json
|
| 11 |
+
import os
|
| 12 |
+
|
| 13 |
+
import util.misc as utils
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from panopticapi.evaluation import pq_compute
|
| 17 |
+
except ImportError:
|
| 18 |
+
pass
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class PanopticEvaluator(object):
|
| 22 |
+
def __init__(self, ann_file, ann_folder, output_dir="panoptic_eval"):
|
| 23 |
+
self.gt_json = ann_file
|
| 24 |
+
self.gt_folder = ann_folder
|
| 25 |
+
if utils.is_main_process():
|
| 26 |
+
if not os.path.exists(output_dir):
|
| 27 |
+
os.mkdir(output_dir)
|
| 28 |
+
self.output_dir = output_dir
|
| 29 |
+
self.predictions = []
|
| 30 |
+
|
| 31 |
+
def update(self, predictions):
|
| 32 |
+
for p in predictions:
|
| 33 |
+
with open(os.path.join(self.output_dir, p["file_name"]), "wb") as f:
|
| 34 |
+
f.write(p.pop("png_string"))
|
| 35 |
+
|
| 36 |
+
self.predictions += predictions
|
| 37 |
+
|
| 38 |
+
def synchronize_between_processes(self):
|
| 39 |
+
all_predictions = utils.all_gather(self.predictions)
|
| 40 |
+
merged_predictions = []
|
| 41 |
+
for p in all_predictions:
|
| 42 |
+
merged_predictions += p
|
| 43 |
+
self.predictions = merged_predictions
|
| 44 |
+
|
| 45 |
+
def summarize(self):
|
| 46 |
+
if utils.is_main_process():
|
| 47 |
+
json_data = {"annotations": self.predictions}
|
| 48 |
+
predictions_json = os.path.join(self.output_dir, "predictions.json")
|
| 49 |
+
with open(predictions_json, "w") as f:
|
| 50 |
+
f.write(json.dumps(json_data))
|
| 51 |
+
return pq_compute(self.gt_json, predictions_json, gt_folder=self.gt_folder, pred_folder=self.output_dir)
|
| 52 |
+
return None
|
perception_models/apps/detection/DETA_pe/datasets/samplers.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
# Modified from codes in torch.utils.data.distributed
|
| 7 |
+
# ------------------------------------------------------------------------
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
import math
|
| 11 |
+
import os
|
| 12 |
+
from collections import defaultdict
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.distributed as dist
|
| 16 |
+
|
| 17 |
+
from fvcore.common.timer import Timer
|
| 18 |
+
from lvis import LVIS
|
| 19 |
+
from torch.utils.data.sampler import Sampler
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def load_dataset_dicts(json_file):
|
| 23 |
+
timer = Timer()
|
| 24 |
+
lvis_api = LVIS(json_file)
|
| 25 |
+
if timer.seconds() > 1:
|
| 26 |
+
print("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
|
| 27 |
+
|
| 28 |
+
img_ids = sorted(lvis_api.imgs.keys())
|
| 29 |
+
imgs = lvis_api.load_imgs(img_ids)
|
| 30 |
+
anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]
|
| 31 |
+
|
| 32 |
+
imgs_anns = list(zip(imgs, anns))
|
| 33 |
+
print(
|
| 34 |
+
"Loaded {} images in the LVIS format from {}".format(len(imgs_anns), json_file)
|
| 35 |
+
)
|
| 36 |
+
dataset_dicts = []
|
| 37 |
+
|
| 38 |
+
for img_dict, anno_dict_list in imgs_anns:
|
| 39 |
+
record = {}
|
| 40 |
+
image_id = record["image_id"] = img_dict["id"]
|
| 41 |
+
objs = []
|
| 42 |
+
for anno in anno_dict_list:
|
| 43 |
+
# Check that the image_id in this annotation is the same as
|
| 44 |
+
# the image_id we're looking at.
|
| 45 |
+
# This fails only when the data parsing logic or the annotation file is buggy.
|
| 46 |
+
assert anno["image_id"] == image_id
|
| 47 |
+
obj = {}
|
| 48 |
+
# Convert 1-indexed to 0-indexed
|
| 49 |
+
obj["category_id"] = anno["category_id"] - 1
|
| 50 |
+
|
| 51 |
+
objs.append(obj)
|
| 52 |
+
record["annotations"] = objs
|
| 53 |
+
dataset_dicts.append(record)
|
| 54 |
+
|
| 55 |
+
return dataset_dicts
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def repeat_factors_from_category_frequency(dataset_dicts, repeat_thresh, sqrt=True):
|
| 59 |
+
# 1. For each category c, compute the fraction of images that contain it: f(c)
|
| 60 |
+
category_freq = defaultdict(int)
|
| 61 |
+
for dataset_dict in dataset_dicts: # For each image (without repeats)
|
| 62 |
+
cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
|
| 63 |
+
for cat_id in cat_ids:
|
| 64 |
+
category_freq[cat_id] += 1
|
| 65 |
+
num_images = len(dataset_dicts)
|
| 66 |
+
for k, v in category_freq.items():
|
| 67 |
+
category_freq[k] = v / num_images
|
| 68 |
+
|
| 69 |
+
# 2. For each category c, compute the category-level repeat factor:
|
| 70 |
+
# r(c) = max(1, sqrt(t / f(c)))
|
| 71 |
+
category_rep = {
|
| 72 |
+
cat_id: max(
|
| 73 |
+
1.0,
|
| 74 |
+
(
|
| 75 |
+
math.sqrt(repeat_thresh / cat_freq)
|
| 76 |
+
if sqrt
|
| 77 |
+
else (repeat_thresh / cat_freq)
|
| 78 |
+
),
|
| 79 |
+
)
|
| 80 |
+
for cat_id, cat_freq in category_freq.items()
|
| 81 |
+
}
|
| 82 |
+
for cat_id in sorted(category_rep.keys()):
|
| 83 |
+
print(
|
| 84 |
+
f"Cat ID {cat_id}: freq={category_freq[cat_id]:.2f}, rep={category_rep[cat_id]:.2f}"
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
# 3. For each image I, compute the image-level repeat factor:
|
| 88 |
+
# r(I) = max_{c in I} r(c)
|
| 89 |
+
rep_factors = []
|
| 90 |
+
for dataset_dict in dataset_dicts:
|
| 91 |
+
cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
|
| 92 |
+
rep_factor = max({category_rep[cat_id] for cat_id in cat_ids}, default=1.0)
|
| 93 |
+
rep_factors.append(rep_factor)
|
| 94 |
+
|
| 95 |
+
return torch.tensor(rep_factors, dtype=torch.float32)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class RepeatFactorTrainingSampler(Sampler):
|
| 99 |
+
def __init__(
|
| 100 |
+
self,
|
| 101 |
+
dataset,
|
| 102 |
+
num_replicas=None,
|
| 103 |
+
rank=None,
|
| 104 |
+
local_rank=None,
|
| 105 |
+
local_size=None,
|
| 106 |
+
shuffle=True,
|
| 107 |
+
):
|
| 108 |
+
if num_replicas is None:
|
| 109 |
+
if not dist.is_available():
|
| 110 |
+
raise RuntimeError("Requires distributed package to be available")
|
| 111 |
+
num_replicas = dist.get_world_size()
|
| 112 |
+
if rank is None:
|
| 113 |
+
if not dist.is_available():
|
| 114 |
+
raise RuntimeError("Requires distributed package to be available")
|
| 115 |
+
rank = dist.get_rank()
|
| 116 |
+
self.dataset = dataset
|
| 117 |
+
self.num_replicas = num_replicas
|
| 118 |
+
self.rank = rank
|
| 119 |
+
self.epoch = 0
|
| 120 |
+
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
|
| 121 |
+
self.total_size = self.num_samples * self.num_replicas
|
| 122 |
+
self.shuffle = shuffle
|
| 123 |
+
|
| 124 |
+
json_file = (
|
| 125 |
+
"/checkpoint/onevision/peizesun/public_data/d2_data/lvis/lvis_v1_train.json"
|
| 126 |
+
)
|
| 127 |
+
dataset_dicts = load_dataset_dicts(json_file)
|
| 128 |
+
repeat_factors = repeat_factors_from_category_frequency(
|
| 129 |
+
dataset_dicts, repeat_thresh=0.001
|
| 130 |
+
)
|
| 131 |
+
# Split into whole number (_int_part) and fractional (_frac_part) parts.
|
| 132 |
+
self._int_part = torch.trunc(repeat_factors)
|
| 133 |
+
self._frac_part = repeat_factors - self._int_part
|
| 134 |
+
|
| 135 |
+
def _get_epoch_indices(self, generator):
|
| 136 |
+
"""
|
| 137 |
+
Create a list of dataset indices (with repeats) to use for one epoch.
|
| 138 |
+
|
| 139 |
+
Args:
|
| 140 |
+
generator (torch.Generator): pseudo random number generator used for
|
| 141 |
+
stochastic rounding.
|
| 142 |
+
|
| 143 |
+
Returns:
|
| 144 |
+
torch.Tensor: list of dataset indices to use in one epoch. Each index
|
| 145 |
+
is repeated based on its calculated repeat factor.
|
| 146 |
+
"""
|
| 147 |
+
# Since repeat factors are fractional, we use stochastic rounding so
|
| 148 |
+
# that the target repeat factor is achieved in expectation over the
|
| 149 |
+
# course of training
|
| 150 |
+
rands = torch.rand(len(self._frac_part), generator=generator)
|
| 151 |
+
rep_factors = self._int_part + (rands < self._frac_part).float()
|
| 152 |
+
# Construct a list of indices in which we repeat images as specified
|
| 153 |
+
indices = []
|
| 154 |
+
for dataset_index, rep_factor in enumerate(rep_factors):
|
| 155 |
+
indices.extend([dataset_index] * int(rep_factor.item()))
|
| 156 |
+
return torch.tensor(indices, dtype=torch.int64)
|
| 157 |
+
|
| 158 |
+
def __iter__(self):
|
| 159 |
+
if self.shuffle:
|
| 160 |
+
g = torch.Generator()
|
| 161 |
+
g.manual_seed(self.epoch)
|
| 162 |
+
# Sample indices with repeats determined by stochastic rounding; each
|
| 163 |
+
# "epoch" may have a slightly different size due to the rounding.
|
| 164 |
+
rfs_indices = self._get_epoch_indices(g)
|
| 165 |
+
# deterministically shuffle based on epoch
|
| 166 |
+
randperm = torch.randperm(len(rfs_indices), generator=g)
|
| 167 |
+
indices = rfs_indices[randperm].tolist()
|
| 168 |
+
else:
|
| 169 |
+
g = torch.Generator()
|
| 170 |
+
g.manual_seed(0)
|
| 171 |
+
# Sample indices with repeats determined by stochastic rounding; each
|
| 172 |
+
# "epoch" may have a slightly different size due to the rounding.
|
| 173 |
+
rfs_indices = self._get_epoch_indices(g)
|
| 174 |
+
indices = rfs_indices.tolist()
|
| 175 |
+
|
| 176 |
+
# add extra samples to make it evenly divisible
|
| 177 |
+
if self.total_size > len(indices):
|
| 178 |
+
indices += indices[: (self.total_size - len(indices))]
|
| 179 |
+
assert len(indices) == self.total_size
|
| 180 |
+
# subsample
|
| 181 |
+
offset = self.num_samples * self.rank
|
| 182 |
+
indices = indices[offset : offset + self.num_samples]
|
| 183 |
+
assert len(indices) == self.num_samples
|
| 184 |
+
|
| 185 |
+
return iter(indices)
|
| 186 |
+
else:
|
| 187 |
+
self.num_samples = int(math.ceil(len(indices) * 1.0 / self.num_replicas))
|
| 188 |
+
self.total_size = self.num_samples * self.num_replicas
|
| 189 |
+
indices += indices[: (self.total_size - len(indices))]
|
| 190 |
+
assert len(indices) == self.total_size
|
| 191 |
+
# subsample
|
| 192 |
+
offset = self.num_samples * self.rank
|
| 193 |
+
indices = indices[offset : offset + self.num_samples]
|
| 194 |
+
assert len(indices) == self.num_samples
|
| 195 |
+
|
| 196 |
+
return iter(indices)
|
| 197 |
+
|
| 198 |
+
def __len__(self):
|
| 199 |
+
return self.num_samples
|
| 200 |
+
|
| 201 |
+
def set_epoch(self, epoch):
|
| 202 |
+
self.epoch = epoch
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
class DistributedSampler(Sampler):
|
| 206 |
+
"""Sampler that restricts data loading to a subset of the dataset.
|
| 207 |
+
It is especially useful in conjunction with
|
| 208 |
+
:class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
|
| 209 |
+
process can pass a DistributedSampler instance as a DataLoader sampler,
|
| 210 |
+
and load a subset of the original dataset that is exclusive to it.
|
| 211 |
+
.. note::
|
| 212 |
+
Dataset is assumed to be of constant size.
|
| 213 |
+
Arguments:
|
| 214 |
+
dataset: Dataset used for sampling.
|
| 215 |
+
num_replicas (optional): Number of processes participating in
|
| 216 |
+
distributed training.
|
| 217 |
+
rank (optional): Rank of the current process within num_replicas.
|
| 218 |
+
"""
|
| 219 |
+
|
| 220 |
+
def __init__(
|
| 221 |
+
self,
|
| 222 |
+
dataset,
|
| 223 |
+
num_replicas=None,
|
| 224 |
+
rank=None,
|
| 225 |
+
local_rank=None,
|
| 226 |
+
local_size=None,
|
| 227 |
+
shuffle=True,
|
| 228 |
+
):
|
| 229 |
+
if num_replicas is None:
|
| 230 |
+
if not dist.is_available():
|
| 231 |
+
raise RuntimeError("Requires distributed package to be available")
|
| 232 |
+
num_replicas = dist.get_world_size()
|
| 233 |
+
if rank is None:
|
| 234 |
+
if not dist.is_available():
|
| 235 |
+
raise RuntimeError("Requires distributed package to be available")
|
| 236 |
+
rank = dist.get_rank()
|
| 237 |
+
self.dataset = dataset
|
| 238 |
+
self.num_replicas = num_replicas
|
| 239 |
+
self.rank = rank
|
| 240 |
+
self.epoch = 0
|
| 241 |
+
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
|
| 242 |
+
self.total_size = self.num_samples * self.num_replicas
|
| 243 |
+
self.shuffle = shuffle
|
| 244 |
+
|
| 245 |
+
def __iter__(self):
|
| 246 |
+
if self.shuffle:
|
| 247 |
+
# deterministically shuffle based on epoch
|
| 248 |
+
g = torch.Generator()
|
| 249 |
+
g.manual_seed(self.epoch)
|
| 250 |
+
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
| 251 |
+
else:
|
| 252 |
+
indices = torch.arange(len(self.dataset)).tolist()
|
| 253 |
+
|
| 254 |
+
# add extra samples to make it evenly divisible
|
| 255 |
+
indices += indices[: (self.total_size - len(indices))]
|
| 256 |
+
assert len(indices) == self.total_size
|
| 257 |
+
|
| 258 |
+
# subsample
|
| 259 |
+
offset = self.num_samples * self.rank
|
| 260 |
+
indices = indices[offset : offset + self.num_samples]
|
| 261 |
+
assert len(indices) == self.num_samples
|
| 262 |
+
|
| 263 |
+
return iter(indices)
|
| 264 |
+
|
| 265 |
+
def __len__(self):
|
| 266 |
+
return self.num_samples
|
| 267 |
+
|
| 268 |
+
def set_epoch(self, epoch):
|
| 269 |
+
self.epoch = epoch
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
class NodeDistributedSampler(Sampler):
|
| 273 |
+
"""Sampler that restricts data loading to a subset of the dataset.
|
| 274 |
+
It is especially useful in conjunction with
|
| 275 |
+
:class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
|
| 276 |
+
process can pass a DistributedSampler instance as a DataLoader sampler,
|
| 277 |
+
and load a subset of the original dataset that is exclusive to it.
|
| 278 |
+
.. note::
|
| 279 |
+
Dataset is assumed to be of constant size.
|
| 280 |
+
Arguments:
|
| 281 |
+
dataset: Dataset used for sampling.
|
| 282 |
+
num_replicas (optional): Number of processes participating in
|
| 283 |
+
distributed training.
|
| 284 |
+
rank (optional): Rank of the current process within num_replicas.
|
| 285 |
+
"""
|
| 286 |
+
|
| 287 |
+
def __init__(
|
| 288 |
+
self,
|
| 289 |
+
dataset,
|
| 290 |
+
num_replicas=None,
|
| 291 |
+
rank=None,
|
| 292 |
+
local_rank=None,
|
| 293 |
+
local_size=None,
|
| 294 |
+
shuffle=True,
|
| 295 |
+
):
|
| 296 |
+
if num_replicas is None:
|
| 297 |
+
if not dist.is_available():
|
| 298 |
+
raise RuntimeError("Requires distributed package to be available")
|
| 299 |
+
num_replicas = dist.get_world_size()
|
| 300 |
+
if rank is None:
|
| 301 |
+
if not dist.is_available():
|
| 302 |
+
raise RuntimeError("Requires distributed package to be available")
|
| 303 |
+
rank = dist.get_rank()
|
| 304 |
+
if local_rank is None:
|
| 305 |
+
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
| 306 |
+
if local_size is None:
|
| 307 |
+
local_size = int(os.environ.get("LOCAL_SIZE", 1))
|
| 308 |
+
self.dataset = dataset
|
| 309 |
+
self.shuffle = shuffle
|
| 310 |
+
self.num_replicas = num_replicas
|
| 311 |
+
self.num_parts = local_size
|
| 312 |
+
self.rank = rank
|
| 313 |
+
self.local_rank = local_rank
|
| 314 |
+
self.epoch = 0
|
| 315 |
+
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
|
| 316 |
+
self.total_size = self.num_samples * self.num_replicas
|
| 317 |
+
|
| 318 |
+
self.total_size_parts = self.num_samples * self.num_replicas // self.num_parts
|
| 319 |
+
|
| 320 |
+
def __iter__(self):
|
| 321 |
+
if self.shuffle:
|
| 322 |
+
# deterministically shuffle based on epoch
|
| 323 |
+
g = torch.Generator()
|
| 324 |
+
g.manual_seed(self.epoch)
|
| 325 |
+
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
| 326 |
+
else:
|
| 327 |
+
indices = torch.arange(len(self.dataset)).tolist()
|
| 328 |
+
indices = [i for i in indices if i % self.num_parts == self.local_rank]
|
| 329 |
+
|
| 330 |
+
# add extra samples to make it evenly divisible
|
| 331 |
+
indices += indices[: (self.total_size_parts - len(indices))]
|
| 332 |
+
assert len(indices) == self.total_size_parts
|
| 333 |
+
|
| 334 |
+
# subsample
|
| 335 |
+
indices = indices[
|
| 336 |
+
self.rank
|
| 337 |
+
// self.num_parts : self.total_size_parts : self.num_replicas
|
| 338 |
+
// self.num_parts
|
| 339 |
+
]
|
| 340 |
+
assert len(indices) == self.num_samples
|
| 341 |
+
|
| 342 |
+
return iter(indices)
|
| 343 |
+
|
| 344 |
+
def __len__(self):
|
| 345 |
+
return self.num_samples
|
| 346 |
+
|
| 347 |
+
def set_epoch(self, epoch):
|
| 348 |
+
self.epoch = epoch
|
perception_models/apps/detection/DETA_pe/datasets/torchvision_datasets/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
|
| 7 |
+
from .coco import CocoDetection
|
perception_models/apps/detection/DETA_pe/datasets/torchvision_datasets/coco.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
# Modified from torchvision
|
| 7 |
+
# ------------------------------------------------------------------------
|
| 8 |
+
|
| 9 |
+
"""
|
| 10 |
+
Copy-Paste from torchvision, but add utility of caching images on memory
|
| 11 |
+
"""
|
| 12 |
+
from torchvision.datasets.vision import VisionDataset
|
| 13 |
+
from PIL import Image
|
| 14 |
+
import os
|
| 15 |
+
import os.path
|
| 16 |
+
import tqdm
|
| 17 |
+
from io import BytesIO
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class CocoDetection(VisionDataset):
|
| 21 |
+
"""`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.
|
| 22 |
+
Args:
|
| 23 |
+
root (string): Root directory where images are downloaded to.
|
| 24 |
+
annFile (string): Path to json annotation file.
|
| 25 |
+
transform (callable, optional): A function/transform that takes in an PIL image
|
| 26 |
+
and returns a transformed version. E.g, ``transforms.ToTensor``
|
| 27 |
+
target_transform (callable, optional): A function/transform that takes in the
|
| 28 |
+
target and transforms it.
|
| 29 |
+
transforms (callable, optional): A function/transform that takes input sample and its target as entry
|
| 30 |
+
and returns a transformed version.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None,
|
| 34 |
+
cache_mode=False, local_rank=0, local_size=1):
|
| 35 |
+
super(CocoDetection, self).__init__(root, transforms, transform, target_transform)
|
| 36 |
+
from pycocotools.coco import COCO
|
| 37 |
+
self.coco = COCO(annFile)
|
| 38 |
+
self.ids = list(sorted(self.coco.imgs.keys()))
|
| 39 |
+
self.cache_mode = cache_mode
|
| 40 |
+
self.local_rank = local_rank
|
| 41 |
+
self.local_size = local_size
|
| 42 |
+
if cache_mode:
|
| 43 |
+
self.cache = {}
|
| 44 |
+
self.cache_images()
|
| 45 |
+
|
| 46 |
+
def cache_images(self):
|
| 47 |
+
self.cache = {}
|
| 48 |
+
for index, img_id in zip(tqdm.trange(len(self.ids)), self.ids):
|
| 49 |
+
if index % self.local_size != self.local_rank:
|
| 50 |
+
continue
|
| 51 |
+
path = self.coco.loadImgs(img_id)[0]['file_name']
|
| 52 |
+
with open(os.path.join(self.root, path), 'rb') as f:
|
| 53 |
+
self.cache[path] = f.read()
|
| 54 |
+
|
| 55 |
+
def get_image(self, path):
|
| 56 |
+
if self.cache_mode:
|
| 57 |
+
if path not in self.cache.keys():
|
| 58 |
+
with open(os.path.join(self.root, path), 'rb') as f:
|
| 59 |
+
self.cache[path] = f.read()
|
| 60 |
+
return Image.open(BytesIO(self.cache[path])).convert('RGB')
|
| 61 |
+
return Image.open(os.path.join(self.root, path)).convert('RGB')
|
| 62 |
+
|
| 63 |
+
def __getitem__(self, index):
|
| 64 |
+
"""
|
| 65 |
+
Args:
|
| 66 |
+
index (int): Index
|
| 67 |
+
Returns:
|
| 68 |
+
tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
|
| 69 |
+
"""
|
| 70 |
+
coco = self.coco
|
| 71 |
+
img_id = self.ids[index]
|
| 72 |
+
ann_ids = coco.getAnnIds(imgIds=img_id)
|
| 73 |
+
target = coco.loadAnns(ann_ids)
|
| 74 |
+
|
| 75 |
+
path = coco.loadImgs(img_id)[0]['file_name']
|
| 76 |
+
|
| 77 |
+
img = self.get_image(path)
|
| 78 |
+
if self.transforms is not None:
|
| 79 |
+
img, target = self.transforms(img, target)
|
| 80 |
+
|
| 81 |
+
return img, target
|
| 82 |
+
|
| 83 |
+
def __len__(self):
|
| 84 |
+
return len(self.ids)
|
perception_models/apps/detection/DETA_pe/datasets/transforms.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
# Modified from DETR (https://github.com/facebookresearch/detr)
|
| 7 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 8 |
+
# ------------------------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
Transforms and data augmentation for both image + bbox.
|
| 12 |
+
"""
|
| 13 |
+
import random
|
| 14 |
+
|
| 15 |
+
import PIL
|
| 16 |
+
import torch
|
| 17 |
+
import torchvision.transforms as T
|
| 18 |
+
import torchvision.transforms.functional as F
|
| 19 |
+
|
| 20 |
+
from util.box_ops import box_xyxy_to_cxcywh
|
| 21 |
+
from util.misc import interpolate
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def crop(image, target, region):
|
| 25 |
+
cropped_image = F.crop(image, *region)
|
| 26 |
+
|
| 27 |
+
target = target.copy()
|
| 28 |
+
i, j, h, w = region
|
| 29 |
+
|
| 30 |
+
# should we do something wrt the original size?
|
| 31 |
+
target["size"] = torch.tensor([h, w])
|
| 32 |
+
|
| 33 |
+
fields = ["labels", "area", "iscrowd"]
|
| 34 |
+
|
| 35 |
+
if "boxes" in target:
|
| 36 |
+
boxes = target["boxes"]
|
| 37 |
+
max_size = torch.as_tensor([w, h], dtype=torch.float32)
|
| 38 |
+
cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
|
| 39 |
+
cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
|
| 40 |
+
cropped_boxes = cropped_boxes.clamp(min=0)
|
| 41 |
+
area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
|
| 42 |
+
target["boxes"] = cropped_boxes.reshape(-1, 4)
|
| 43 |
+
target["area"] = area
|
| 44 |
+
fields.append("boxes")
|
| 45 |
+
|
| 46 |
+
if "masks" in target:
|
| 47 |
+
# FIXME should we update the area here if there are no boxes?
|
| 48 |
+
target["masks"] = target["masks"][:, i : i + h, j : j + w]
|
| 49 |
+
fields.append("masks")
|
| 50 |
+
|
| 51 |
+
# remove elements for which the boxes or masks that have zero area
|
| 52 |
+
if "boxes" in target or "masks" in target:
|
| 53 |
+
# favor boxes selection when defining which elements to keep
|
| 54 |
+
# this is compatible with previous implementation
|
| 55 |
+
if "boxes" in target:
|
| 56 |
+
cropped_boxes = target["boxes"].reshape(-1, 2, 2)
|
| 57 |
+
keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
|
| 58 |
+
else:
|
| 59 |
+
keep = target["masks"].flatten(1).any(1)
|
| 60 |
+
|
| 61 |
+
for field in fields:
|
| 62 |
+
target[field] = target[field][keep]
|
| 63 |
+
|
| 64 |
+
return cropped_image, target
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def hflip(image, target):
|
| 68 |
+
flipped_image = F.hflip(image)
|
| 69 |
+
|
| 70 |
+
w, h = image.size
|
| 71 |
+
|
| 72 |
+
target = target.copy()
|
| 73 |
+
if "boxes" in target:
|
| 74 |
+
boxes = target["boxes"]
|
| 75 |
+
boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor(
|
| 76 |
+
[-1, 1, -1, 1]
|
| 77 |
+
) + torch.as_tensor([w, 0, w, 0])
|
| 78 |
+
target["boxes"] = boxes
|
| 79 |
+
|
| 80 |
+
if "masks" in target:
|
| 81 |
+
target["masks"] = target["masks"].flip(-1)
|
| 82 |
+
|
| 83 |
+
return flipped_image, target
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def resize(image, target, size, max_size=None):
|
| 87 |
+
# size can be min_size (scalar) or (w, h) tuple
|
| 88 |
+
|
| 89 |
+
def get_size_with_aspect_ratio(image_size, size, max_size=None):
|
| 90 |
+
w, h = image_size
|
| 91 |
+
if max_size is not None:
|
| 92 |
+
min_original_size = float(min((w, h)))
|
| 93 |
+
max_original_size = float(max((w, h)))
|
| 94 |
+
if max_original_size / min_original_size * size > max_size:
|
| 95 |
+
size = int(round(max_size * min_original_size / max_original_size))
|
| 96 |
+
|
| 97 |
+
if (w <= h and w == size) or (h <= w and h == size):
|
| 98 |
+
return (h, w)
|
| 99 |
+
if w < h:
|
| 100 |
+
ow = size
|
| 101 |
+
oh = int(size * h / w)
|
| 102 |
+
else:
|
| 103 |
+
oh = size
|
| 104 |
+
ow = int(size * w / h)
|
| 105 |
+
return (oh, ow)
|
| 106 |
+
|
| 107 |
+
def get_size(image_size, size, max_size=None):
|
| 108 |
+
if isinstance(size, (list, tuple)):
|
| 109 |
+
return size[::-1]
|
| 110 |
+
else:
|
| 111 |
+
return get_size_with_aspect_ratio(image_size, size, max_size)
|
| 112 |
+
|
| 113 |
+
size = get_size(image.size, size, max_size)
|
| 114 |
+
rescaled_image = F.resize(image, size)
|
| 115 |
+
|
| 116 |
+
if target is None:
|
| 117 |
+
return rescaled_image, None
|
| 118 |
+
|
| 119 |
+
ratios = tuple(
|
| 120 |
+
float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)
|
| 121 |
+
)
|
| 122 |
+
ratio_width, ratio_height = ratios
|
| 123 |
+
|
| 124 |
+
target = target.copy()
|
| 125 |
+
if "boxes" in target:
|
| 126 |
+
boxes = target["boxes"]
|
| 127 |
+
scaled_boxes = boxes * torch.as_tensor(
|
| 128 |
+
[ratio_width, ratio_height, ratio_width, ratio_height]
|
| 129 |
+
)
|
| 130 |
+
target["boxes"] = scaled_boxes
|
| 131 |
+
|
| 132 |
+
if "area" in target:
|
| 133 |
+
area = target["area"]
|
| 134 |
+
scaled_area = area * (ratio_width * ratio_height)
|
| 135 |
+
target["area"] = scaled_area
|
| 136 |
+
|
| 137 |
+
h, w = size
|
| 138 |
+
target["size"] = torch.tensor([h, w])
|
| 139 |
+
|
| 140 |
+
if "masks" in target:
|
| 141 |
+
target["masks"] = (
|
| 142 |
+
interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0]
|
| 143 |
+
> 0.5
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
return rescaled_image, target
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def pad(image, target, padding):
|
| 150 |
+
# assumes that we only pad on the bottom right corners
|
| 151 |
+
padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
|
| 152 |
+
if target is None:
|
| 153 |
+
return padded_image, None
|
| 154 |
+
target = target.copy()
|
| 155 |
+
# should we do something wrt the original size?
|
| 156 |
+
target["size"] = torch.tensor(padded_image[::-1])
|
| 157 |
+
if "masks" in target:
|
| 158 |
+
target["masks"] = torch.nn.functional.pad(
|
| 159 |
+
target["masks"], (0, padding[0], 0, padding[1])
|
| 160 |
+
)
|
| 161 |
+
return padded_image, target
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class RandomCrop(object):
|
| 165 |
+
def __init__(self, size):
|
| 166 |
+
self.size = size
|
| 167 |
+
|
| 168 |
+
def __call__(self, img, target):
|
| 169 |
+
region = T.RandomCrop.get_params(img, self.size)
|
| 170 |
+
return crop(img, target, region)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class RandomSizeCrop(object):
|
| 174 |
+
def __init__(self, min_size: int, max_size: int):
|
| 175 |
+
self.min_size = min_size
|
| 176 |
+
self.max_size = max_size
|
| 177 |
+
|
| 178 |
+
def __call__(self, img: PIL.Image.Image, target: dict):
|
| 179 |
+
w = random.randint(self.min_size, min(img.width, self.max_size))
|
| 180 |
+
h = random.randint(self.min_size, min(img.height, self.max_size))
|
| 181 |
+
region = T.RandomCrop.get_params(img, [h, w])
|
| 182 |
+
return crop(img, target, region)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class CenterCrop(object):
|
| 186 |
+
def __init__(self, size):
|
| 187 |
+
self.size = size
|
| 188 |
+
|
| 189 |
+
def __call__(self, img, target):
|
| 190 |
+
image_width, image_height = img.size
|
| 191 |
+
crop_height, crop_width = self.size
|
| 192 |
+
crop_top = int(round((image_height - crop_height) / 2.0))
|
| 193 |
+
crop_left = int(round((image_width - crop_width) / 2.0))
|
| 194 |
+
return crop(img, target, (crop_top, crop_left, crop_height, crop_width))
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class RandomHorizontalFlip(object):
|
| 198 |
+
def __init__(self, p=0.5):
|
| 199 |
+
self.p = p
|
| 200 |
+
|
| 201 |
+
def __call__(self, img, target):
|
| 202 |
+
if random.random() < self.p:
|
| 203 |
+
return hflip(img, target)
|
| 204 |
+
return img, target
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class RandomResize(object):
|
| 208 |
+
def __init__(self, sizes, max_size=None):
|
| 209 |
+
assert isinstance(sizes, (list, tuple))
|
| 210 |
+
self.sizes = sizes
|
| 211 |
+
self.max_size = max_size
|
| 212 |
+
|
| 213 |
+
def __call__(self, img, target=None):
|
| 214 |
+
size = random.choice(self.sizes)
|
| 215 |
+
return resize(img, target, size, self.max_size)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class RandomPad(object):
|
| 219 |
+
def __init__(self, max_pad):
|
| 220 |
+
self.max_pad = max_pad
|
| 221 |
+
|
| 222 |
+
def __call__(self, img, target):
|
| 223 |
+
pad_x = random.randint(0, self.max_pad)
|
| 224 |
+
pad_y = random.randint(0, self.max_pad)
|
| 225 |
+
return pad(img, target, (pad_x, pad_y))
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class RandomSelect(object):
|
| 229 |
+
"""
|
| 230 |
+
Randomly selects between transforms1 and transforms2,
|
| 231 |
+
with probability p for transforms1 and (1 - p) for transforms2
|
| 232 |
+
"""
|
| 233 |
+
|
| 234 |
+
def __init__(self, transforms1, transforms2, p=0.5):
|
| 235 |
+
self.transforms1 = transforms1
|
| 236 |
+
self.transforms2 = transforms2
|
| 237 |
+
self.p = p
|
| 238 |
+
|
| 239 |
+
def __call__(self, img, target):
|
| 240 |
+
if random.random() < self.p:
|
| 241 |
+
return self.transforms1(img, target)
|
| 242 |
+
return self.transforms2(img, target)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
class ToTensor(object):
|
| 246 |
+
def __call__(self, img, target):
|
| 247 |
+
return F.to_tensor(img), target
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
class RandomErasingP05(object):
|
| 251 |
+
def __init__(self):
|
| 252 |
+
self.eraser = T.Compose(
|
| 253 |
+
[
|
| 254 |
+
T.ToTensor(),
|
| 255 |
+
T.RandomErasing(
|
| 256 |
+
p=0.5, scale=(0.02, 0.2), ratio=(0.1, 6), value="random"
|
| 257 |
+
),
|
| 258 |
+
T.ToPILImage(),
|
| 259 |
+
]
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
def __call__(self, img, target):
|
| 263 |
+
return self.eraser(img), target
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
class RandomErasing(object):
|
| 267 |
+
def __init__(self, *args, **kwargs):
|
| 268 |
+
self.eraser = T.RandomErasing(*args, **kwargs)
|
| 269 |
+
|
| 270 |
+
def __call__(self, img, target):
|
| 271 |
+
return self.eraser(img), target
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
class ColorJitter(object):
|
| 275 |
+
def __init__(self, jitter=(0.2, 0.2, 0.2, 0.1), p=0.5):
|
| 276 |
+
self.color_jitter = T.ColorJitter(*jitter)
|
| 277 |
+
self.p = p
|
| 278 |
+
|
| 279 |
+
def __call__(self, img, target):
|
| 280 |
+
if random.random() < self.p:
|
| 281 |
+
return self.color_jitter(img), target
|
| 282 |
+
return img, target
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class RandomGrayscale(object):
|
| 286 |
+
def __init__(self, p=0.5):
|
| 287 |
+
self.random_gray = T.RandomGrayscale(p=p)
|
| 288 |
+
|
| 289 |
+
def __call__(self, img, target):
|
| 290 |
+
return self.random_gray(img), target
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
class Normalize(object):
|
| 294 |
+
def __init__(self, mean, std):
|
| 295 |
+
self.mean = mean
|
| 296 |
+
self.std = std
|
| 297 |
+
|
| 298 |
+
def __call__(self, image, target=None):
|
| 299 |
+
image = F.normalize(image, mean=self.mean, std=self.std)
|
| 300 |
+
if target is None:
|
| 301 |
+
return image, None
|
| 302 |
+
target = target.copy()
|
| 303 |
+
h, w = image.shape[-2:]
|
| 304 |
+
if "boxes" in target:
|
| 305 |
+
boxes = target["boxes"]
|
| 306 |
+
boxes = box_xyxy_to_cxcywh(boxes)
|
| 307 |
+
boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
|
| 308 |
+
target["boxes"] = boxes
|
| 309 |
+
return image, target
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
class Compose(object):
|
| 313 |
+
def __init__(self, transforms):
|
| 314 |
+
self.transforms = transforms
|
| 315 |
+
|
| 316 |
+
def __call__(self, image, target):
|
| 317 |
+
for t in self.transforms:
|
| 318 |
+
image, target = t(image, target)
|
| 319 |
+
return image, target
|
| 320 |
+
|
| 321 |
+
def __repr__(self):
|
| 322 |
+
format_string = self.__class__.__name__ + "("
|
| 323 |
+
for t in self.transforms:
|
| 324 |
+
format_string += "\n"
|
| 325 |
+
format_string += " {0}".format(t)
|
| 326 |
+
format_string += "\n)"
|
| 327 |
+
return format_string
|
perception_models/apps/detection/DETA_pe/engine.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
# Modified from DETR (https://github.com/facebookresearch/detr)
|
| 7 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 8 |
+
# ------------------------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
Train and eval functions used in main.py
|
| 12 |
+
"""
|
| 13 |
+
import math
|
| 14 |
+
import os
|
| 15 |
+
import sys
|
| 16 |
+
from typing import Iterable
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import util.misc as utils
|
| 20 |
+
from datasets.coco_eval import CocoEvaluator, convert_to_xywh
|
| 21 |
+
from datasets.data_prefetcher import data_prefetcher
|
| 22 |
+
from datasets.panoptic_eval import PanopticEvaluator
|
| 23 |
+
from util.ema import requires_grad, update_ema
|
| 24 |
+
from util.misc import NestedTensor
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def train_one_epoch(
|
| 28 |
+
model: torch.nn.Module,
|
| 29 |
+
criterion: torch.nn.Module,
|
| 30 |
+
data_loader: Iterable,
|
| 31 |
+
optimizer: torch.optim.Optimizer,
|
| 32 |
+
device: torch.device,
|
| 33 |
+
epoch: int,
|
| 34 |
+
max_norm: float = 0,
|
| 35 |
+
ema: torch.nn.Module = None,
|
| 36 |
+
ema_decay: float = 0.999,
|
| 37 |
+
):
|
| 38 |
+
model.train()
|
| 39 |
+
criterion.train()
|
| 40 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
| 41 |
+
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
|
| 42 |
+
metric_logger.add_meter(
|
| 43 |
+
"class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
|
| 44 |
+
)
|
| 45 |
+
metric_logger.add_meter(
|
| 46 |
+
"grad_norm", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
|
| 47 |
+
)
|
| 48 |
+
header = "Epoch: [{}]".format(epoch)
|
| 49 |
+
print_freq = 10
|
| 50 |
+
|
| 51 |
+
prefetcher = data_prefetcher(data_loader, device, prefetch=True)
|
| 52 |
+
samples, targets = prefetcher.next()
|
| 53 |
+
|
| 54 |
+
# for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
|
| 55 |
+
for _ in metric_logger.log_every(range(len(data_loader)), print_freq, header):
|
| 56 |
+
outputs = model(samples)
|
| 57 |
+
loss_dict = criterion(outputs, targets)
|
| 58 |
+
weight_dict = criterion.weight_dict
|
| 59 |
+
losses = sum(
|
| 60 |
+
loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# reduce losses over all GPUs for logging purposes
|
| 64 |
+
loss_dict_reduced = utils.reduce_dict(loss_dict)
|
| 65 |
+
loss_dict_reduced_unscaled = {
|
| 66 |
+
f"{k}_unscaled": v for k, v in loss_dict_reduced.items()
|
| 67 |
+
}
|
| 68 |
+
loss_dict_reduced_scaled = {
|
| 69 |
+
k: v * weight_dict[k]
|
| 70 |
+
for k, v in loss_dict_reduced.items()
|
| 71 |
+
if k in weight_dict
|
| 72 |
+
}
|
| 73 |
+
losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
|
| 74 |
+
|
| 75 |
+
loss_value = losses_reduced_scaled.item()
|
| 76 |
+
|
| 77 |
+
if not math.isfinite(loss_value):
|
| 78 |
+
print("Loss is {}, stopping training".format(loss_value))
|
| 79 |
+
print(loss_dict_reduced)
|
| 80 |
+
sys.exit(1)
|
| 81 |
+
|
| 82 |
+
optimizer.zero_grad()
|
| 83 |
+
losses.backward()
|
| 84 |
+
if max_norm > 0:
|
| 85 |
+
grad_total_norm = torch.nn.utils.clip_grad_norm_(
|
| 86 |
+
model.parameters(), max_norm
|
| 87 |
+
)
|
| 88 |
+
else:
|
| 89 |
+
grad_total_norm = utils.get_total_grad_norm(model.parameters(), max_norm)
|
| 90 |
+
optimizer.step()
|
| 91 |
+
|
| 92 |
+
if ema is not None:
|
| 93 |
+
update_ema(ema, model.module, ema_decay)
|
| 94 |
+
# torch.cuda.empty_cache()
|
| 95 |
+
|
| 96 |
+
metric_logger.update(
|
| 97 |
+
loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled
|
| 98 |
+
)
|
| 99 |
+
metric_logger.update(class_error=loss_dict_reduced["class_error"])
|
| 100 |
+
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
|
| 101 |
+
metric_logger.update(grad_norm=grad_total_norm)
|
| 102 |
+
|
| 103 |
+
samples, targets = prefetcher.next()
|
| 104 |
+
# gather the stats from all processes
|
| 105 |
+
metric_logger.synchronize_between_processes()
|
| 106 |
+
print("Averaged stats:", metric_logger)
|
| 107 |
+
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@torch.no_grad()
|
| 111 |
+
def evaluate(
|
| 112 |
+
model_no_ema,
|
| 113 |
+
criterion,
|
| 114 |
+
postprocessors,
|
| 115 |
+
data_loader,
|
| 116 |
+
base_ds,
|
| 117 |
+
device,
|
| 118 |
+
output_dir,
|
| 119 |
+
test_hflip_aug,
|
| 120 |
+
tta,
|
| 121 |
+
soft_nms,
|
| 122 |
+
ema=None,
|
| 123 |
+
save_result=False,
|
| 124 |
+
save_result_dir="",
|
| 125 |
+
soft_nms_method="quad",
|
| 126 |
+
nms_thresh=0.7,
|
| 127 |
+
quad_scale=0.5,
|
| 128 |
+
lsj_img_size=1824,
|
| 129 |
+
):
|
| 130 |
+
model = model_no_ema if ema is None else ema
|
| 131 |
+
model.eval()
|
| 132 |
+
criterion.eval()
|
| 133 |
+
|
| 134 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
| 135 |
+
metric_logger.add_meter(
|
| 136 |
+
"class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
|
| 137 |
+
)
|
| 138 |
+
header = "Test:"
|
| 139 |
+
|
| 140 |
+
iou_types = tuple(k for k in ("segm", "bbox") if k in postprocessors.keys())
|
| 141 |
+
coco_evaluator = CocoEvaluator(base_ds, iou_types)
|
| 142 |
+
# coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75]
|
| 143 |
+
|
| 144 |
+
panoptic_evaluator = None
|
| 145 |
+
if "panoptic" in postprocessors.keys():
|
| 146 |
+
panoptic_evaluator = PanopticEvaluator(
|
| 147 |
+
data_loader.dataset.ann_file,
|
| 148 |
+
data_loader.dataset.ann_folder,
|
| 149 |
+
output_dir=os.path.join(output_dir, "panoptic_eval"),
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
prediction_list = []
|
| 153 |
+
for samples, targets in metric_logger.log_every(data_loader, 10, header):
|
| 154 |
+
samples = samples.to(device)
|
| 155 |
+
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
|
| 156 |
+
|
| 157 |
+
if test_hflip_aug:
|
| 158 |
+
assert (
|
| 159 |
+
samples.tensors.shape[0] == 1
|
| 160 |
+
), "test_hflip_aug only supports batch size 1"
|
| 161 |
+
assert (
|
| 162 |
+
samples.tensors.shape[1] == 6
|
| 163 |
+
), "test_hflip_aug requires two images in a batch"
|
| 164 |
+
first_samples = NestedTensor(samples.tensors[:, :3], samples.mask)
|
| 165 |
+
outputs = model(first_samples)
|
| 166 |
+
flipped_samples = NestedTensor(samples.tensors[:, 3:], samples.mask)
|
| 167 |
+
flipped_outputs = model(flipped_samples)
|
| 168 |
+
else:
|
| 169 |
+
outputs = model(samples)
|
| 170 |
+
loss_dict = criterion(outputs, targets)
|
| 171 |
+
weight_dict = criterion.weight_dict
|
| 172 |
+
|
| 173 |
+
# reduce losses over all GPUs for logging purposes
|
| 174 |
+
loss_dict_reduced = utils.reduce_dict(loss_dict)
|
| 175 |
+
loss_dict_reduced_scaled = {
|
| 176 |
+
k: v * weight_dict[k]
|
| 177 |
+
for k, v in loss_dict_reduced.items()
|
| 178 |
+
if k in weight_dict
|
| 179 |
+
}
|
| 180 |
+
loss_dict_reduced_unscaled = {
|
| 181 |
+
f"{k}_unscaled": v for k, v in loss_dict_reduced.items()
|
| 182 |
+
}
|
| 183 |
+
metric_logger.update(
|
| 184 |
+
loss=sum(loss_dict_reduced_scaled.values()),
|
| 185 |
+
**loss_dict_reduced_scaled,
|
| 186 |
+
**loss_dict_reduced_unscaled,
|
| 187 |
+
)
|
| 188 |
+
metric_logger.update(class_error=loss_dict_reduced["class_error"])
|
| 189 |
+
|
| 190 |
+
orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
|
| 191 |
+
if test_hflip_aug:
|
| 192 |
+
new_outputs = {}
|
| 193 |
+
pred_logits = outputs["pred_logits"]
|
| 194 |
+
pred_boxes = outputs["pred_boxes"]
|
| 195 |
+
|
| 196 |
+
flipped_pred_logits = flipped_outputs["pred_logits"]
|
| 197 |
+
flipped_pred_boxes = flipped_outputs["pred_boxes"]
|
| 198 |
+
|
| 199 |
+
reflipped_pred_boxes = flipped_pred_boxes[
|
| 200 |
+
:, :, [0, 1, 2, 3]
|
| 201 |
+
] * torch.as_tensor([-1, 1, 1, 1]).to(
|
| 202 |
+
flipped_pred_boxes.device
|
| 203 |
+
) + torch.as_tensor(
|
| 204 |
+
[1, 0, 0, 0]
|
| 205 |
+
).to(
|
| 206 |
+
flipped_pred_boxes.device
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
new_pred_logits = torch.cat([pred_logits, flipped_pred_logits], dim=1)
|
| 210 |
+
new_pred_boxes = torch.cat([pred_boxes, reflipped_pred_boxes], dim=1)
|
| 211 |
+
|
| 212 |
+
new_outputs["pred_logits"] = new_pred_logits
|
| 213 |
+
new_outputs["pred_boxes"] = new_pred_boxes
|
| 214 |
+
results = postprocessors["bbox"](
|
| 215 |
+
new_outputs,
|
| 216 |
+
orig_target_sizes,
|
| 217 |
+
soft_nms=soft_nms,
|
| 218 |
+
method=soft_nms_method,
|
| 219 |
+
nms_thresh=nms_thresh,
|
| 220 |
+
quad_scale=quad_scale,
|
| 221 |
+
)
|
| 222 |
+
else:
|
| 223 |
+
results = postprocessors["bbox"](
|
| 224 |
+
outputs,
|
| 225 |
+
orig_target_sizes,
|
| 226 |
+
soft_nms=soft_nms,
|
| 227 |
+
method=soft_nms_method,
|
| 228 |
+
nms_thresh=nms_thresh,
|
| 229 |
+
quad_scale=quad_scale,
|
| 230 |
+
)
|
| 231 |
+
if "segm" in postprocessors.keys():
|
| 232 |
+
target_sizes = torch.stack([t["size"] for t in targets], dim=0)
|
| 233 |
+
results = postprocessors["segm"](
|
| 234 |
+
results, outputs, orig_target_sizes, target_sizes
|
| 235 |
+
)
|
| 236 |
+
res = {
|
| 237 |
+
target["image_id"].item(): output
|
| 238 |
+
for target, output in zip(targets, results)
|
| 239 |
+
}
|
| 240 |
+
if coco_evaluator is not None:
|
| 241 |
+
coco_evaluator.update(res)
|
| 242 |
+
|
| 243 |
+
if panoptic_evaluator is not None:
|
| 244 |
+
res_pano = postprocessors["panoptic"](
|
| 245 |
+
outputs, target_sizes, orig_target_sizes
|
| 246 |
+
)
|
| 247 |
+
for i, target in enumerate(targets):
|
| 248 |
+
image_id = target["image_id"].item()
|
| 249 |
+
file_name = f"{image_id:012d}.png"
|
| 250 |
+
res_pano[i]["image_id"] = image_id
|
| 251 |
+
res_pano[i]["file_name"] = file_name
|
| 252 |
+
|
| 253 |
+
panoptic_evaluator.update(res_pano)
|
| 254 |
+
|
| 255 |
+
for target, output in zip(targets, results):
|
| 256 |
+
res_cpu = {
|
| 257 |
+
target["image_id"].item(): {
|
| 258 |
+
"boxes": output["boxes"].cpu(),
|
| 259 |
+
"labels": output["labels"].cpu(),
|
| 260 |
+
"scores": output["scores"].cpu(),
|
| 261 |
+
}
|
| 262 |
+
}
|
| 263 |
+
prediction_list.append(res_cpu)
|
| 264 |
+
|
| 265 |
+
# gather the stats from all processes
|
| 266 |
+
metric_logger.synchronize_between_processes()
|
| 267 |
+
print("Averaged stats:", metric_logger)
|
| 268 |
+
|
| 269 |
+
if save_result:
|
| 270 |
+
|
| 271 |
+
from torch import distributed as dist
|
| 272 |
+
|
| 273 |
+
os.makedirs(save_result_dir, exist_ok=True)
|
| 274 |
+
rank = dist.get_rank()
|
| 275 |
+
torch.save(
|
| 276 |
+
prediction_list,
|
| 277 |
+
os.path.join(save_result_dir, f"val2017_prediction_{rank}.pth"),
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
if coco_evaluator is not None:
|
| 281 |
+
coco_evaluator.synchronize_between_processes()
|
| 282 |
+
if panoptic_evaluator is not None:
|
| 283 |
+
panoptic_evaluator.synchronize_between_processes()
|
| 284 |
+
|
| 285 |
+
# accumulate predictions from all images
|
| 286 |
+
if coco_evaluator is not None:
|
| 287 |
+
coco_evaluator.accumulate()
|
| 288 |
+
coco_evaluator.summarize()
|
| 289 |
+
panoptic_res = None
|
| 290 |
+
if panoptic_evaluator is not None:
|
| 291 |
+
panoptic_res = panoptic_evaluator.summarize()
|
| 292 |
+
stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
| 293 |
+
if coco_evaluator is not None:
|
| 294 |
+
if "bbox" in postprocessors.keys():
|
| 295 |
+
stats["coco_eval_bbox"] = coco_evaluator.coco_eval["bbox"].stats.tolist()
|
| 296 |
+
if "segm" in postprocessors.keys():
|
| 297 |
+
stats["coco_eval_masks"] = coco_evaluator.coco_eval["segm"].stats.tolist()
|
| 298 |
+
if panoptic_res is not None:
|
| 299 |
+
stats["PQ_all"] = panoptic_res["All"]
|
| 300 |
+
stats["PQ_th"] = panoptic_res["Things"]
|
| 301 |
+
stats["PQ_st"] = panoptic_res["Stuff"]
|
| 302 |
+
return stats, coco_evaluator
|
| 303 |
+
|
perception_models/apps/detection/DETA_pe/engine_tta.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
# Modified from DETR (https://github.com/facebookresearch/detr)
|
| 7 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 8 |
+
# ------------------------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
Train and eval functions used in main.py
|
| 12 |
+
"""
|
| 13 |
+
import math
|
| 14 |
+
import os
|
| 15 |
+
import sys
|
| 16 |
+
from typing import Iterable
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import util.misc as utils
|
| 20 |
+
from datasets.coco_eval import CocoEvaluator, convert_to_xywh
|
| 21 |
+
from datasets.data_prefetcher import data_prefetcher
|
| 22 |
+
from datasets.panoptic_eval import PanopticEvaluator
|
| 23 |
+
from models.utils_softnms import batched_soft_nms
|
| 24 |
+
from util.misc import NestedTensor
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Make sure this is consistent with datasets/coco.py
|
| 28 |
+
# TODO: make it configurable
|
| 29 |
+
SCALE_RANGES_DICT = {
|
| 30 |
+
1728: [[0, 10000], [32, 10000], [32, 10000],],
|
| 31 |
+
1824: [[0, 10000], [0, 10000], [64, 10000], [64, 10000],],
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def filter_boxes(boxes, min_scale, max_scale):
|
| 36 |
+
"""
|
| 37 |
+
boxes: (N, 4) shape
|
| 38 |
+
"""
|
| 39 |
+
w = boxes[:, 2] - boxes[:, 0]
|
| 40 |
+
h = boxes[:, 3] - boxes[:, 1]
|
| 41 |
+
keep = (w * h > min_scale * min_scale) & (w * h < max_scale * max_scale)
|
| 42 |
+
return keep
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@torch.no_grad()
|
| 46 |
+
def evaluate_tta(
|
| 47 |
+
model_no_ema,
|
| 48 |
+
criterion,
|
| 49 |
+
postprocessors,
|
| 50 |
+
data_loader,
|
| 51 |
+
base_ds,
|
| 52 |
+
device,
|
| 53 |
+
output_dir,
|
| 54 |
+
test_hflip_aug,
|
| 55 |
+
tta,
|
| 56 |
+
soft_nms,
|
| 57 |
+
ema=None,
|
| 58 |
+
save_result=False,
|
| 59 |
+
save_result_dir="",
|
| 60 |
+
soft_nms_method="quad",
|
| 61 |
+
nms_thresh=0.7,
|
| 62 |
+
quad_scale=0.5,
|
| 63 |
+
lsj_img_size=1824,
|
| 64 |
+
):
|
| 65 |
+
model = model_no_ema if ema is None else ema
|
| 66 |
+
model.eval()
|
| 67 |
+
criterion.eval()
|
| 68 |
+
|
| 69 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
| 70 |
+
metric_logger.add_meter(
|
| 71 |
+
"class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
|
| 72 |
+
)
|
| 73 |
+
header = "Test:"
|
| 74 |
+
|
| 75 |
+
iou_types = tuple(k for k in ("segm", "bbox") if k in postprocessors.keys())
|
| 76 |
+
coco_evaluator = CocoEvaluator(base_ds, iou_types)
|
| 77 |
+
# coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75]
|
| 78 |
+
|
| 79 |
+
SCALE_RANGES = SCALE_RANGES_DICT[lsj_img_size]
|
| 80 |
+
IMAGE_SIZE = [lsj_img_size for _ in range(len(SCALE_RANGES))]
|
| 81 |
+
|
| 82 |
+
prediction_list = []
|
| 83 |
+
for samples, targets in metric_logger.log_every(data_loader, 10, header):
|
| 84 |
+
samples = samples.to(device)
|
| 85 |
+
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
|
| 86 |
+
|
| 87 |
+
orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
|
| 88 |
+
metric_logger.update(loss=0, class_error=0, loss_bbox=0, loss_ce=0)
|
| 89 |
+
########################### Begin of inference_one_image ###########################
|
| 90 |
+
if tta:
|
| 91 |
+
assert samples.tensors.shape[0] == 1, "tta only supports batch size 1"
|
| 92 |
+
assert (
|
| 93 |
+
samples.tensors.shape[1] % 3 == 0
|
| 94 |
+
), "tta requires dimensions of samples.tensors to be divisible by 3"
|
| 95 |
+
|
| 96 |
+
all_boxes = []
|
| 97 |
+
all_scores = []
|
| 98 |
+
all_classes = []
|
| 99 |
+
|
| 100 |
+
num_scales = samples.tensors.shape[1] // 3
|
| 101 |
+
for scale_ind in range(num_scales):
|
| 102 |
+
first_samples = NestedTensor(
|
| 103 |
+
samples.tensors[
|
| 104 |
+
:,
|
| 105 |
+
scale_ind * 3 : (scale_ind + 1) * 3,
|
| 106 |
+
: IMAGE_SIZE[scale_ind // 2],
|
| 107 |
+
: IMAGE_SIZE[scale_ind // 2],
|
| 108 |
+
],
|
| 109 |
+
samples.mask[
|
| 110 |
+
:,
|
| 111 |
+
scale_ind,
|
| 112 |
+
: IMAGE_SIZE[scale_ind // 2],
|
| 113 |
+
: IMAGE_SIZE[scale_ind // 2],
|
| 114 |
+
],
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
if scale_ind % 2 == 0:
|
| 118 |
+
######## no flip #######
|
| 119 |
+
outputs = model(first_samples)
|
| 120 |
+
noaug_results = postprocessors["bbox"](
|
| 121 |
+
outputs,
|
| 122 |
+
orig_target_sizes,
|
| 123 |
+
soft_nms=soft_nms,
|
| 124 |
+
method=soft_nms_method,
|
| 125 |
+
nms_thresh=nms_thresh,
|
| 126 |
+
quad_scale=quad_scale,
|
| 127 |
+
)
|
| 128 |
+
keep = filter_boxes(
|
| 129 |
+
noaug_results[0]["boxes"], *SCALE_RANGES[scale_ind // 2]
|
| 130 |
+
)
|
| 131 |
+
all_boxes.append(noaug_results[0]["boxes"][keep])
|
| 132 |
+
all_scores.append(noaug_results[0]["scores"][keep])
|
| 133 |
+
all_classes.append(noaug_results[0]["labels"][keep])
|
| 134 |
+
else:
|
| 135 |
+
######## flipped #######
|
| 136 |
+
flipped_outputs = model(first_samples)
|
| 137 |
+
flipped_pred_logits = flipped_outputs["pred_logits"]
|
| 138 |
+
flipped_pred_boxes = flipped_outputs["pred_boxes"]
|
| 139 |
+
reflipped_pred_boxes = flipped_pred_boxes[
|
| 140 |
+
:, :, [0, 1, 2, 3]
|
| 141 |
+
] * torch.as_tensor([-1, 1, 1, 1]).to(
|
| 142 |
+
flipped_pred_boxes.device
|
| 143 |
+
) + torch.as_tensor(
|
| 144 |
+
[1, 0, 0, 0]
|
| 145 |
+
).to(
|
| 146 |
+
flipped_pred_boxes.device
|
| 147 |
+
)
|
| 148 |
+
new_outputs = {}
|
| 149 |
+
new_outputs["pred_logits"] = flipped_pred_logits
|
| 150 |
+
new_outputs["pred_boxes"] = reflipped_pred_boxes
|
| 151 |
+
new_results = postprocessors["bbox"](
|
| 152 |
+
new_outputs,
|
| 153 |
+
orig_target_sizes,
|
| 154 |
+
soft_nms=soft_nms,
|
| 155 |
+
method=soft_nms_method,
|
| 156 |
+
nms_thresh=nms_thresh,
|
| 157 |
+
quad_scale=quad_scale,
|
| 158 |
+
)
|
| 159 |
+
keep = filter_boxes(
|
| 160 |
+
new_results[0]["boxes"], *SCALE_RANGES[scale_ind // 2]
|
| 161 |
+
)
|
| 162 |
+
all_boxes.append(new_results[0]["boxes"][keep])
|
| 163 |
+
all_scores.append(new_results[0]["scores"][keep])
|
| 164 |
+
all_classes.append(new_results[0]["labels"][keep])
|
| 165 |
+
|
| 166 |
+
######## merge #######
|
| 167 |
+
all_boxes = torch.cat(all_boxes, dim=0)
|
| 168 |
+
all_scores = torch.cat(all_scores, dim=0)
|
| 169 |
+
all_classes = torch.cat(all_classes, dim=0)
|
| 170 |
+
|
| 171 |
+
keep_inds, updated_scores = batched_soft_nms(
|
| 172 |
+
all_boxes,
|
| 173 |
+
all_scores,
|
| 174 |
+
all_classes,
|
| 175 |
+
method=soft_nms_method,
|
| 176 |
+
threshold=nms_thresh,
|
| 177 |
+
quad_scale=quad_scale,
|
| 178 |
+
)
|
| 179 |
+
merged_scores = updated_scores
|
| 180 |
+
merged_classes = all_classes[keep_inds]
|
| 181 |
+
merged_boxes = all_boxes[keep_inds]
|
| 182 |
+
|
| 183 |
+
results = [
|
| 184 |
+
{
|
| 185 |
+
"boxes": merged_boxes,
|
| 186 |
+
"scores": merged_scores,
|
| 187 |
+
"labels": merged_classes,
|
| 188 |
+
}
|
| 189 |
+
]
|
| 190 |
+
else:
|
| 191 |
+
outputs = model(samples)
|
| 192 |
+
results = postprocessors["bbox"](outputs, orig_target_sizes)
|
| 193 |
+
|
| 194 |
+
########################### End of inference_one_image ###########################
|
| 195 |
+
res = {
|
| 196 |
+
target["image_id"].item(): output
|
| 197 |
+
for target, output in zip(targets, results)
|
| 198 |
+
}
|
| 199 |
+
if coco_evaluator is not None:
|
| 200 |
+
coco_evaluator.update(res)
|
| 201 |
+
|
| 202 |
+
for target, output in zip(targets, results):
|
| 203 |
+
res_cpu = {
|
| 204 |
+
target["image_id"].item(): {
|
| 205 |
+
"boxes": output["boxes"].cpu(),
|
| 206 |
+
"labels": output["labels"].cpu(),
|
| 207 |
+
"scores": output["scores"].cpu(),
|
| 208 |
+
}
|
| 209 |
+
}
|
| 210 |
+
prediction_list.append(res_cpu)
|
| 211 |
+
|
| 212 |
+
# gather the stats from all processes
|
| 213 |
+
metric_logger.synchronize_between_processes()
|
| 214 |
+
print("Averaged stats:", metric_logger)
|
| 215 |
+
|
| 216 |
+
if save_result:
|
| 217 |
+
from torch import distributed as dist
|
| 218 |
+
|
| 219 |
+
os.makedirs(save_result_dir, exist_ok=True)
|
| 220 |
+
|
| 221 |
+
rank = dist.get_rank()
|
| 222 |
+
torch.save(
|
| 223 |
+
prediction_list,
|
| 224 |
+
os.path.join(save_result_dir, f"val2017_prediction_{rank}.pth"),
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
if coco_evaluator is not None:
|
| 228 |
+
coco_evaluator.synchronize_between_processes()
|
| 229 |
+
|
| 230 |
+
# accumulate predictions from all images
|
| 231 |
+
if coco_evaluator is not None:
|
| 232 |
+
coco_evaluator.accumulate()
|
| 233 |
+
coco_evaluator.summarize()
|
| 234 |
+
|
| 235 |
+
stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
| 236 |
+
if coco_evaluator is not None:
|
| 237 |
+
if "bbox" in postprocessors.keys():
|
| 238 |
+
stats["coco_eval_bbox"] = coco_evaluator.coco_eval["bbox"].stats.tolist()
|
| 239 |
+
return stats, coco_evaluator
|
perception_models/apps/detection/DETA_pe/main.py
ADDED
|
@@ -0,0 +1,754 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from
|
| 2 |
+
# ------------------------------------------------------------------------
|
| 3 |
+
# Deformable DETR
|
| 4 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 6 |
+
# ------------------------------------------------------------------------
|
| 7 |
+
# Modified from DETR (https://github.com/facebookresearch/detr)
|
| 8 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 9 |
+
# ------------------------------------------------------------------------
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import datetime
|
| 14 |
+
import json
|
| 15 |
+
import os
|
| 16 |
+
import random
|
| 17 |
+
import time
|
| 18 |
+
from copy import deepcopy
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
import datasets
|
| 22 |
+
import datasets.samplers as samplers
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
import torch
|
| 26 |
+
import util.misc as utils
|
| 27 |
+
from datasets import build_dataset, get_coco_api_from_dataset
|
| 28 |
+
from engine import evaluate, train_one_epoch
|
| 29 |
+
from engine_tta import evaluate_tta
|
| 30 |
+
from models import build_model
|
| 31 |
+
from torch.utils.data import DataLoader
|
| 32 |
+
from util.ema import requires_grad, update_ema
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_args_parser():
|
| 36 |
+
parser = argparse.ArgumentParser("Deformable DETR Detector", add_help=False)
|
| 37 |
+
parser.add_argument("--lr", default=2e-4, type=float)
|
| 38 |
+
parser.add_argument(
|
| 39 |
+
"--lr_backbone_names", default=["backbone.0"], type=str, nargs="+"
|
| 40 |
+
)
|
| 41 |
+
parser.add_argument("--lr_backbone", default=2e-5, type=float)
|
| 42 |
+
parser.add_argument(
|
| 43 |
+
"--lr_linear_proj_names",
|
| 44 |
+
default=["reference_points", "sampling_offsets"],
|
| 45 |
+
type=str,
|
| 46 |
+
nargs="+",
|
| 47 |
+
)
|
| 48 |
+
parser.add_argument("--lr_linear_proj_mult", default=0.1, type=float)
|
| 49 |
+
parser.add_argument("--batch_size", default=2, type=int)
|
| 50 |
+
parser.add_argument("--weight_decay", default=1e-4, type=float)
|
| 51 |
+
parser.add_argument("--epochs", default=50, type=int)
|
| 52 |
+
parser.add_argument("--eval_per_epochs", default=1, type=int)
|
| 53 |
+
parser.add_argument("--save_per_epochs", default=1, type=int)
|
| 54 |
+
parser.add_argument("--lr_drop", default=40, type=int)
|
| 55 |
+
parser.add_argument("--lr_drop_epochs", default=None, type=int, nargs="+")
|
| 56 |
+
parser.add_argument(
|
| 57 |
+
"--clip_max_norm", default=0.1, type=float, help="gradient clipping max norm"
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
parser.add_argument("--sgd", action="store_true")
|
| 61 |
+
parser.add_argument("--ema", action="store_true")
|
| 62 |
+
parser.add_argument("--ema_decay", default=0.999, type=float)
|
| 63 |
+
|
| 64 |
+
# Variants of Deformable DETR
|
| 65 |
+
parser.add_argument("--with_box_refine", default=False, action="store_true")
|
| 66 |
+
parser.add_argument("--two_stage", default=False, action="store_true")
|
| 67 |
+
|
| 68 |
+
# Model parameters
|
| 69 |
+
parser.add_argument(
|
| 70 |
+
"--frozen_weights",
|
| 71 |
+
type=str,
|
| 72 |
+
default=None,
|
| 73 |
+
help="Path to the pretrained model. If set, only the mask head will be trained",
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
# * Backbone
|
| 77 |
+
parser.add_argument(
|
| 78 |
+
"--backbone",
|
| 79 |
+
default="resnet50",
|
| 80 |
+
type=str,
|
| 81 |
+
help="Name of the convolutional backbone to use",
|
| 82 |
+
)
|
| 83 |
+
parser.add_argument(
|
| 84 |
+
"--backbone_size",
|
| 85 |
+
default="Gwin384",
|
| 86 |
+
type=str,
|
| 87 |
+
help="backbone size",
|
| 88 |
+
)
|
| 89 |
+
parser.add_argument(
|
| 90 |
+
"--backbone_path",
|
| 91 |
+
default="",
|
| 92 |
+
type=str,
|
| 93 |
+
)
|
| 94 |
+
parser.add_argument(
|
| 95 |
+
"--backbone_lrd",
|
| 96 |
+
default=1.0,
|
| 97 |
+
type=float,
|
| 98 |
+
)
|
| 99 |
+
parser.add_argument(
|
| 100 |
+
"--backbone_layers",
|
| 101 |
+
default=12,
|
| 102 |
+
type=int,
|
| 103 |
+
)
|
| 104 |
+
parser.add_argument(
|
| 105 |
+
"--backbone_init_values",
|
| 106 |
+
default=0.0,
|
| 107 |
+
type=float,
|
| 108 |
+
)
|
| 109 |
+
parser.add_argument(
|
| 110 |
+
"--backbone_tile_posemb",
|
| 111 |
+
default=False,
|
| 112 |
+
type=bool,
|
| 113 |
+
)
|
| 114 |
+
parser.add_argument(
|
| 115 |
+
"--backbone_use_act_checkpoint",
|
| 116 |
+
action="store_true",
|
| 117 |
+
help="If true, we use act_checkpoint in backbone",
|
| 118 |
+
)
|
| 119 |
+
parser.add_argument(
|
| 120 |
+
"--backbone_act_checkpoint_ratio",
|
| 121 |
+
default=1.0,
|
| 122 |
+
type=float,
|
| 123 |
+
)
|
| 124 |
+
parser.add_argument(
|
| 125 |
+
"--backbone_tta_rope",
|
| 126 |
+
action="store_true",
|
| 127 |
+
)
|
| 128 |
+
parser.add_argument(
|
| 129 |
+
"--backbone_multi_layer",
|
| 130 |
+
action="store_true",
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
parser.add_argument(
|
| 134 |
+
"--backbone_win_aug",
|
| 135 |
+
action="store_true",
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
parser.add_argument(
|
| 139 |
+
"--backbone_dp",
|
| 140 |
+
default=-1.0,
|
| 141 |
+
type=float,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
parser.add_argument(
|
| 145 |
+
"--bf16",
|
| 146 |
+
action="store_true",
|
| 147 |
+
)
|
| 148 |
+
parser.add_argument(
|
| 149 |
+
"--fp16",
|
| 150 |
+
action="store_true",
|
| 151 |
+
)
|
| 152 |
+
parser.add_argument(
|
| 153 |
+
"--dilation",
|
| 154 |
+
action="store_true",
|
| 155 |
+
help="If true, we replace stride with dilation in the last convolutional block (DC5)",
|
| 156 |
+
)
|
| 157 |
+
parser.add_argument(
|
| 158 |
+
"--position_embedding",
|
| 159 |
+
default="sine",
|
| 160 |
+
type=str,
|
| 161 |
+
choices=("sine", "learned"),
|
| 162 |
+
help="Type of positional embedding to use on top of the image features",
|
| 163 |
+
)
|
| 164 |
+
parser.add_argument(
|
| 165 |
+
"--position_embedding_scale",
|
| 166 |
+
default=2 * np.pi,
|
| 167 |
+
type=float,
|
| 168 |
+
help="position / size * scale",
|
| 169 |
+
)
|
| 170 |
+
parser.add_argument(
|
| 171 |
+
"--num_feature_levels", default=4, type=int, help="number of feature levels"
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
# * Transformer
|
| 175 |
+
parser.add_argument(
|
| 176 |
+
"--enc_layers",
|
| 177 |
+
default=6,
|
| 178 |
+
type=int,
|
| 179 |
+
help="Number of encoding layers in the transformer",
|
| 180 |
+
)
|
| 181 |
+
parser.add_argument(
|
| 182 |
+
"--dec_layers",
|
| 183 |
+
default=6,
|
| 184 |
+
type=int,
|
| 185 |
+
help="Number of decoding layers in the transformer",
|
| 186 |
+
)
|
| 187 |
+
parser.add_argument(
|
| 188 |
+
"--dim_feedforward",
|
| 189 |
+
default=1024,
|
| 190 |
+
type=int,
|
| 191 |
+
help="Intermediate size of the feedforward layers in the transformer blocks",
|
| 192 |
+
)
|
| 193 |
+
parser.add_argument(
|
| 194 |
+
"--hidden_dim",
|
| 195 |
+
default=256,
|
| 196 |
+
type=int,
|
| 197 |
+
help="Size of the embeddings (dimension of the transformer)",
|
| 198 |
+
)
|
| 199 |
+
parser.add_argument(
|
| 200 |
+
"--dropout", default=0.1, type=float, help="Dropout applied in the transformer"
|
| 201 |
+
)
|
| 202 |
+
parser.add_argument(
|
| 203 |
+
"--nheads",
|
| 204 |
+
default=8,
|
| 205 |
+
type=int,
|
| 206 |
+
help="Number of attention heads inside the transformer's attentions",
|
| 207 |
+
)
|
| 208 |
+
parser.add_argument(
|
| 209 |
+
"--num_queries", default=300, type=int, help="Number of query slots"
|
| 210 |
+
)
|
| 211 |
+
parser.add_argument("--dec_n_points", default=4, type=int)
|
| 212 |
+
parser.add_argument("--enc_n_points", default=4, type=int)
|
| 213 |
+
|
| 214 |
+
# * Segmentation
|
| 215 |
+
parser.add_argument(
|
| 216 |
+
"--masks",
|
| 217 |
+
action="store_true",
|
| 218 |
+
help="Train segmentation head if the flag is provided",
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
# Loss
|
| 222 |
+
parser.add_argument(
|
| 223 |
+
"--no_aux_loss",
|
| 224 |
+
dest="aux_loss",
|
| 225 |
+
action="store_false",
|
| 226 |
+
help="Disables auxiliary decoding losses (loss at each layer)",
|
| 227 |
+
)
|
| 228 |
+
parser.add_argument("--use_fed_loss", action="store_true")
|
| 229 |
+
|
| 230 |
+
# * Matcher
|
| 231 |
+
parser.add_argument("--assign_first_stage", action="store_true")
|
| 232 |
+
parser.add_argument("--assign_second_stage", action="store_true")
|
| 233 |
+
parser.add_argument(
|
| 234 |
+
"--set_cost_class",
|
| 235 |
+
default=2,
|
| 236 |
+
type=float,
|
| 237 |
+
help="Class coefficient in the matching cost",
|
| 238 |
+
)
|
| 239 |
+
parser.add_argument(
|
| 240 |
+
"--set_cost_bbox",
|
| 241 |
+
default=5,
|
| 242 |
+
type=float,
|
| 243 |
+
help="L1 box coefficient in the matching cost",
|
| 244 |
+
)
|
| 245 |
+
parser.add_argument(
|
| 246 |
+
"--set_cost_giou",
|
| 247 |
+
default=2,
|
| 248 |
+
type=float,
|
| 249 |
+
help="giou box coefficient in the matching cost",
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
# * Loss coefficients
|
| 253 |
+
parser.add_argument("--mask_loss_coef", default=1, type=float)
|
| 254 |
+
parser.add_argument("--dice_loss_coef", default=1, type=float)
|
| 255 |
+
parser.add_argument("--cls_loss_coef", default=2, type=float)
|
| 256 |
+
parser.add_argument("--bbox_loss_coef", default=5, type=float)
|
| 257 |
+
parser.add_argument("--giou_loss_coef", default=2, type=float)
|
| 258 |
+
parser.add_argument("--focal_alpha", default=0.25, type=float)
|
| 259 |
+
|
| 260 |
+
# dataset parameters
|
| 261 |
+
parser.add_argument("--new_mean_std", action="store_true")
|
| 262 |
+
parser.add_argument("--dataset_file", default="coco")
|
| 263 |
+
parser.add_argument("--coco_path", default="./data/coco", type=str)
|
| 264 |
+
parser.add_argument("--coco_panoptic_path", type=str)
|
| 265 |
+
parser.add_argument("--remove_difficult", action="store_true")
|
| 266 |
+
parser.add_argument("--bigger", action="store_true")
|
| 267 |
+
parser.add_argument("--lsj", action="store_true")
|
| 268 |
+
parser.add_argument("--lsj_ms", action="store_true")
|
| 269 |
+
|
| 270 |
+
parser.add_argument("--lsj_img_size", default=1024, type=int)
|
| 271 |
+
parser.add_argument("--lsj_img_train_min", default=480, type=int)
|
| 272 |
+
parser.add_argument("--lsj_img_size_max", default=-1, type=int)
|
| 273 |
+
parser.add_argument("--lsj_strong_aug", action="store_true")
|
| 274 |
+
|
| 275 |
+
parser.add_argument("--save_result", action="store_true")
|
| 276 |
+
parser.add_argument("--save_result_dir", default="", type=str)
|
| 277 |
+
parser.add_argument("--test_hflip_aug", action="store_true")
|
| 278 |
+
parser.add_argument("--tta", action="store_true")
|
| 279 |
+
parser.add_argument("--soft_nms", action="store_true")
|
| 280 |
+
parser.add_argument("--soft_nms_method", default="quad", type=str)
|
| 281 |
+
parser.add_argument("--nms_thresh", default=0.7, type=float)
|
| 282 |
+
parser.add_argument("--quad_scale", default=0.5, type=float)
|
| 283 |
+
parser.add_argument(
|
| 284 |
+
"--output_dir", default="", help="path where to save, empty for no saving"
|
| 285 |
+
)
|
| 286 |
+
parser.add_argument(
|
| 287 |
+
"--device", default="cuda", help="device to use for training / testing"
|
| 288 |
+
)
|
| 289 |
+
parser.add_argument("--seed", default=42, type=int)
|
| 290 |
+
parser.add_argument("--resume", default="", help="resume from checkpoint")
|
| 291 |
+
parser.add_argument("--auto_resume", action="store_true")
|
| 292 |
+
|
| 293 |
+
parser.add_argument(
|
| 294 |
+
"--resume_norope",
|
| 295 |
+
action="store_true",
|
| 296 |
+
help="resume from checkpoint without rope params",
|
| 297 |
+
)
|
| 298 |
+
parser.add_argument("--finetune", default="", help="finetune from checkpoint")
|
| 299 |
+
parser.add_argument("--keep_class_embed", action="store_true")
|
| 300 |
+
parser.add_argument(
|
| 301 |
+
"--start_epoch", default=0, type=int, metavar="N", help="start epoch"
|
| 302 |
+
)
|
| 303 |
+
parser.add_argument("--eval", action="store_true")
|
| 304 |
+
parser.add_argument("--num_workers", default=8, type=int)
|
| 305 |
+
parser.add_argument(
|
| 306 |
+
"--cache_mode",
|
| 307 |
+
default=False,
|
| 308 |
+
action="store_true",
|
| 309 |
+
help="whether to cache images on memory",
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
return parser
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
# lr_backbone_names = ["backbone.0", "backbone.neck", "input_proj", "transformer.encoder"]
|
| 316 |
+
def match_name_keywords(n, name_keywords):
|
| 317 |
+
out = False
|
| 318 |
+
for b in name_keywords:
|
| 319 |
+
if b in n:
|
| 320 |
+
out = True
|
| 321 |
+
break
|
| 322 |
+
return out
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def get_vit_lr_decay_rate_vev01(name, lr_decay_rate=1.0, num_layers=12):
|
| 326 |
+
layer_id = num_layers + 1
|
| 327 |
+
if ".positional_embedding" in name or ".conv1" in name or ".ln_pre" in name:
|
| 328 |
+
layer_id = 0
|
| 329 |
+
elif ".resblocks." in name:
|
| 330 |
+
layer_id = int(name[name.find(".resblocks.") :].split(".")[2]) + 1
|
| 331 |
+
return lr_decay_rate ** (num_layers + 1 - layer_id)
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def custom_lr(model_without_ddp, args):
|
| 335 |
+
param_dicts = [
|
| 336 |
+
{
|
| 337 |
+
"params": [
|
| 338 |
+
p
|
| 339 |
+
for n, p in model_without_ddp.named_parameters()
|
| 340 |
+
if not match_name_keywords(n, args.lr_backbone_names)
|
| 341 |
+
and not match_name_keywords(n, args.lr_linear_proj_names)
|
| 342 |
+
and p.requires_grad
|
| 343 |
+
],
|
| 344 |
+
"lr": args.lr,
|
| 345 |
+
},
|
| 346 |
+
{
|
| 347 |
+
"params": [
|
| 348 |
+
p
|
| 349 |
+
for n, p in model_without_ddp.named_parameters()
|
| 350 |
+
if match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad
|
| 351 |
+
],
|
| 352 |
+
"lr": args.lr * args.lr_linear_proj_mult,
|
| 353 |
+
},
|
| 354 |
+
]
|
| 355 |
+
if "vev01" in args.backbone:
|
| 356 |
+
for p_key, p_value in model_without_ddp.named_parameters():
|
| 357 |
+
if (
|
| 358 |
+
match_name_keywords(p_key, args.lr_backbone_names)
|
| 359 |
+
and p_value.requires_grad
|
| 360 |
+
):
|
| 361 |
+
p_lr = args.lr_backbone * get_vit_lr_decay_rate_vev01(
|
| 362 |
+
p_key, args.backbone_lrd, args.backbone_layers
|
| 363 |
+
)
|
| 364 |
+
param_dicts.append(
|
| 365 |
+
{
|
| 366 |
+
"params": [p_value],
|
| 367 |
+
"lr": p_lr,
|
| 368 |
+
}
|
| 369 |
+
)
|
| 370 |
+
print(f"param_name: {p_key}, lr: {p_lr}")
|
| 371 |
+
else:
|
| 372 |
+
param_groups_backbone = {
|
| 373 |
+
"params": [
|
| 374 |
+
p
|
| 375 |
+
for n, p in model_without_ddp.named_parameters()
|
| 376 |
+
if match_name_keywords(n, args.lr_backbone_names) and p.requires_grad
|
| 377 |
+
],
|
| 378 |
+
"lr": args.lr_backbone,
|
| 379 |
+
}
|
| 380 |
+
param_dicts.append(param_groups_backbone)
|
| 381 |
+
|
| 382 |
+
return param_dicts
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def main(args):
|
| 386 |
+
utils.init_distributed_mode(args)
|
| 387 |
+
print("git:\n {}\n".format(utils.get_sha()))
|
| 388 |
+
|
| 389 |
+
if args.frozen_weights is not None:
|
| 390 |
+
assert args.masks, "Frozen training is meant for segmentation only"
|
| 391 |
+
print(args)
|
| 392 |
+
|
| 393 |
+
device = torch.device(args.device)
|
| 394 |
+
|
| 395 |
+
# fix the seed for reproducibility
|
| 396 |
+
seed = args.seed + utils.get_rank()
|
| 397 |
+
torch.manual_seed(seed)
|
| 398 |
+
np.random.seed(seed)
|
| 399 |
+
random.seed(seed)
|
| 400 |
+
|
| 401 |
+
model, criterion, postprocessors = build_model(args)
|
| 402 |
+
model.to(device)
|
| 403 |
+
|
| 404 |
+
model_without_ddp = model
|
| 405 |
+
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 406 |
+
print("model:", model_without_ddp)
|
| 407 |
+
for n, p in model_without_ddp.named_parameters():
|
| 408 |
+
print(n)
|
| 409 |
+
print("number of params:", n_parameters)
|
| 410 |
+
|
| 411 |
+
if args.ema:
|
| 412 |
+
ema = deepcopy(model).to(device)
|
| 413 |
+
requires_grad(ema, False)
|
| 414 |
+
print(f"EMA Parameters: {sum(p.numel() for p in ema.parameters()):,}")
|
| 415 |
+
|
| 416 |
+
dataset_train = build_dataset(image_set="train", args=args)
|
| 417 |
+
dataset_val = build_dataset(image_set="val", args=args)
|
| 418 |
+
|
| 419 |
+
if args.distributed:
|
| 420 |
+
if args.cache_mode:
|
| 421 |
+
sampler_train = samplers.NodeDistributedSampler(dataset_train)
|
| 422 |
+
sampler_val = samplers.NodeDistributedSampler(dataset_val, shuffle=False)
|
| 423 |
+
else:
|
| 424 |
+
if args.dataset_file == "lvis":
|
| 425 |
+
sampler_train = samplers.RepeatFactorTrainingSampler(dataset_train)
|
| 426 |
+
else:
|
| 427 |
+
sampler_train = samplers.DistributedSampler(dataset_train)
|
| 428 |
+
sampler_val = samplers.DistributedSampler(dataset_val, shuffle=False)
|
| 429 |
+
else:
|
| 430 |
+
sampler_train = torch.utils.data.RandomSampler(dataset_train)
|
| 431 |
+
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
|
| 432 |
+
|
| 433 |
+
batch_sampler_train = torch.utils.data.BatchSampler(
|
| 434 |
+
sampler_train, args.batch_size, drop_last=True
|
| 435 |
+
)
|
| 436 |
+
if args.lsj_ms:
|
| 437 |
+
collator = utils.CollatorLSJMultiscale(args.lsj_img_size, args.tta)
|
| 438 |
+
elif args.lsj:
|
| 439 |
+
lsj_img_size_colla = (
|
| 440 |
+
args.lsj_img_size_max if args.lsj_img_size_max > 0 else args.lsj_img_size
|
| 441 |
+
)
|
| 442 |
+
collator = utils.CollatorLSJ(lsj_img_size_colla, args.tta)
|
| 443 |
+
else:
|
| 444 |
+
collator = utils.collate_fn
|
| 445 |
+
|
| 446 |
+
data_loader_train = DataLoader(
|
| 447 |
+
dataset_train,
|
| 448 |
+
batch_sampler=batch_sampler_train,
|
| 449 |
+
collate_fn=collator,
|
| 450 |
+
num_workers=args.num_workers,
|
| 451 |
+
pin_memory=True,
|
| 452 |
+
)
|
| 453 |
+
data_loader_val = DataLoader(
|
| 454 |
+
dataset_val,
|
| 455 |
+
args.batch_size,
|
| 456 |
+
sampler=sampler_val,
|
| 457 |
+
drop_last=False,
|
| 458 |
+
collate_fn=collator,
|
| 459 |
+
num_workers=args.num_workers,
|
| 460 |
+
pin_memory=True,
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
param_dicts = custom_lr(model_without_ddp, args)
|
| 464 |
+
|
| 465 |
+
if args.sgd:
|
| 466 |
+
optimizer = torch.optim.SGD(
|
| 467 |
+
param_dicts, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay
|
| 468 |
+
)
|
| 469 |
+
else:
|
| 470 |
+
optimizer = torch.optim.AdamW(
|
| 471 |
+
param_dicts, lr=args.lr, weight_decay=args.weight_decay
|
| 472 |
+
)
|
| 473 |
+
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
|
| 474 |
+
|
| 475 |
+
if args.distributed:
|
| 476 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
| 477 |
+
model_without_ddp = model.module
|
| 478 |
+
|
| 479 |
+
if args.dataset_file == "coco_panoptic":
|
| 480 |
+
# We also evaluate AP during panoptic training, on original coco DS
|
| 481 |
+
coco_val = datasets.coco.build("val", args)
|
| 482 |
+
base_ds = get_coco_api_from_dataset(coco_val)
|
| 483 |
+
else:
|
| 484 |
+
base_ds = get_coco_api_from_dataset(dataset_val)
|
| 485 |
+
|
| 486 |
+
if args.frozen_weights is not None:
|
| 487 |
+
checkpoint = torch.load(args.frozen_weights, map_location="cpu")
|
| 488 |
+
model_without_ddp.detr.load_state_dict(checkpoint["model"])
|
| 489 |
+
|
| 490 |
+
if args.tta:
|
| 491 |
+
evaluate_fn = evaluate_tta
|
| 492 |
+
else:
|
| 493 |
+
evaluate_fn = evaluate
|
| 494 |
+
|
| 495 |
+
output_dir = Path(args.output_dir)
|
| 496 |
+
if args.auto_resume:
|
| 497 |
+
resumed_ckpt = os.path.join(args.output_dir, "checkpoint.pth")
|
| 498 |
+
if os.path.exists(resumed_ckpt):
|
| 499 |
+
args.resume = resumed_ckpt
|
| 500 |
+
args.finetune = None
|
| 501 |
+
|
| 502 |
+
if args.finetune:
|
| 503 |
+
checkpoint = torch.load(args.finetune, map_location="cpu")
|
| 504 |
+
state_dict = checkpoint["model"]
|
| 505 |
+
for k in list(state_dict.keys()):
|
| 506 |
+
if "class_embed" in k and not args.keep_class_embed:
|
| 507 |
+
print("removing", k)
|
| 508 |
+
del state_dict[k]
|
| 509 |
+
if "freqs" in k:
|
| 510 |
+
print("removing", k)
|
| 511 |
+
del state_dict[k]
|
| 512 |
+
|
| 513 |
+
missing_keys, unexpected_keys = model_without_ddp.load_state_dict(
|
| 514 |
+
state_dict, strict=False
|
| 515 |
+
)
|
| 516 |
+
unexpected_keys = [
|
| 517 |
+
k
|
| 518 |
+
for k in unexpected_keys
|
| 519 |
+
if not (k.endswith("total_params") or k.endswith("total_ops"))
|
| 520 |
+
]
|
| 521 |
+
if len(missing_keys) > 0:
|
| 522 |
+
print("Missing Keys: {}".format(missing_keys))
|
| 523 |
+
if len(unexpected_keys) > 0:
|
| 524 |
+
print("Unexpected Keys: {}".format(unexpected_keys))
|
| 525 |
+
|
| 526 |
+
if "epoch" in checkpoint:
|
| 527 |
+
print("finetuning from epoch", checkpoint["epoch"])
|
| 528 |
+
|
| 529 |
+
if args.ema:
|
| 530 |
+
ema.load_state_dict(
|
| 531 |
+
checkpoint["ema"] if "ema" in checkpoint else state_dict, strict=False
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
if args.resume:
|
| 535 |
+
print("Resuming training from {}".format(args.resume))
|
| 536 |
+
if args.resume.startswith("https"):
|
| 537 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
| 538 |
+
args.resume, map_location="cpu", check_hash=True
|
| 539 |
+
)
|
| 540 |
+
else:
|
| 541 |
+
checkpoint = torch.load(args.resume, map_location="cpu")
|
| 542 |
+
|
| 543 |
+
if args.resume_norope:
|
| 544 |
+
state_dict = checkpoint["model"]
|
| 545 |
+
for k in list(state_dict.keys()):
|
| 546 |
+
if "freqs" in k:
|
| 547 |
+
print("removing", k)
|
| 548 |
+
del state_dict[k]
|
| 549 |
+
|
| 550 |
+
missing_keys, unexpected_keys = model_without_ddp.load_state_dict(
|
| 551 |
+
state_dict, strict=False
|
| 552 |
+
)
|
| 553 |
+
if args.ema:
|
| 554 |
+
ema.load_state_dict(
|
| 555 |
+
checkpoint["ema"] if "ema" in checkpoint else state_dict,
|
| 556 |
+
strict=False,
|
| 557 |
+
)
|
| 558 |
+
else:
|
| 559 |
+
missing_keys, unexpected_keys = model_without_ddp.load_state_dict(
|
| 560 |
+
checkpoint["model"], strict=False
|
| 561 |
+
)
|
| 562 |
+
if args.ema:
|
| 563 |
+
ema.load_state_dict(
|
| 564 |
+
checkpoint["ema"] if "ema" in checkpoint else state_dict,
|
| 565 |
+
strict=False,
|
| 566 |
+
)
|
| 567 |
+
unexpected_keys = [
|
| 568 |
+
k
|
| 569 |
+
for k in unexpected_keys
|
| 570 |
+
if not (k.endswith("total_params") or k.endswith("total_ops"))
|
| 571 |
+
]
|
| 572 |
+
if len(missing_keys) > 0:
|
| 573 |
+
print("Missing Keys: {}".format(missing_keys))
|
| 574 |
+
if len(unexpected_keys) > 0:
|
| 575 |
+
print("Unexpected Keys: {}".format(unexpected_keys))
|
| 576 |
+
if (
|
| 577 |
+
not args.eval
|
| 578 |
+
and "optimizer" in checkpoint
|
| 579 |
+
and "lr_scheduler" in checkpoint
|
| 580 |
+
and "epoch" in checkpoint
|
| 581 |
+
):
|
| 582 |
+
import copy
|
| 583 |
+
|
| 584 |
+
p_groups = copy.deepcopy(optimizer.param_groups)
|
| 585 |
+
optimizer.load_state_dict(checkpoint["optimizer"])
|
| 586 |
+
for pg, pg_old in zip(optimizer.param_groups, p_groups):
|
| 587 |
+
pg["lr"] = pg_old["lr"]
|
| 588 |
+
pg["initial_lr"] = pg_old["initial_lr"]
|
| 589 |
+
print(optimizer.param_groups)
|
| 590 |
+
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
|
| 591 |
+
# todo: this is a hack for doing experiment that resume from checkpoint and also modify lr scheduler (e.g., decrease lr in advance).
|
| 592 |
+
args.override_resumed_lr_drop = True
|
| 593 |
+
if args.override_resumed_lr_drop:
|
| 594 |
+
print(
|
| 595 |
+
"Warning: (hack) args.override_resumed_lr_drop is set to True, so args.lr_drop would override lr_drop in resumed lr_scheduler."
|
| 596 |
+
)
|
| 597 |
+
lr_scheduler.step_size = args.lr_drop
|
| 598 |
+
lr_scheduler.base_lrs = list(
|
| 599 |
+
map(lambda group: group["initial_lr"], optimizer.param_groups)
|
| 600 |
+
)
|
| 601 |
+
lr_scheduler.step(lr_scheduler.last_epoch)
|
| 602 |
+
args.start_epoch = checkpoint["epoch"] + 1
|
| 603 |
+
# check the resumed model
|
| 604 |
+
if not args.eval:
|
| 605 |
+
test_stats, coco_evaluator = evaluate_fn(
|
| 606 |
+
model,
|
| 607 |
+
criterion,
|
| 608 |
+
postprocessors,
|
| 609 |
+
data_loader_val,
|
| 610 |
+
base_ds,
|
| 611 |
+
device,
|
| 612 |
+
args.output_dir,
|
| 613 |
+
args.test_hflip_aug,
|
| 614 |
+
args.tta,
|
| 615 |
+
args.soft_nms,
|
| 616 |
+
ema if args.ema else None,
|
| 617 |
+
args.save_result,
|
| 618 |
+
args.save_result_dir,
|
| 619 |
+
soft_nms_method=args.soft_nms_method,
|
| 620 |
+
nms_thresh=args.nms_thresh,
|
| 621 |
+
quad_scale=args.quad_scale,
|
| 622 |
+
lsj_img_size=args.lsj_img_size,
|
| 623 |
+
)
|
| 624 |
+
torch.cuda.empty_cache()
|
| 625 |
+
|
| 626 |
+
if args.eval:
|
| 627 |
+
test_stats, coco_evaluator = evaluate_fn(
|
| 628 |
+
model,
|
| 629 |
+
criterion,
|
| 630 |
+
postprocessors,
|
| 631 |
+
data_loader_val,
|
| 632 |
+
base_ds,
|
| 633 |
+
device,
|
| 634 |
+
args.output_dir,
|
| 635 |
+
args.test_hflip_aug,
|
| 636 |
+
args.tta,
|
| 637 |
+
args.soft_nms,
|
| 638 |
+
ema if args.ema else None,
|
| 639 |
+
args.save_result,
|
| 640 |
+
args.save_result_dir,
|
| 641 |
+
soft_nms_method=args.soft_nms_method,
|
| 642 |
+
nms_thresh=args.nms_thresh,
|
| 643 |
+
quad_scale=args.quad_scale,
|
| 644 |
+
lsj_img_size=args.lsj_img_size,
|
| 645 |
+
)
|
| 646 |
+
|
| 647 |
+
if args.output_dir:
|
| 648 |
+
utils.save_on_master(
|
| 649 |
+
coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth"
|
| 650 |
+
)
|
| 651 |
+
return
|
| 652 |
+
|
| 653 |
+
print("Start training")
|
| 654 |
+
start_time = time.time()
|
| 655 |
+
if args.ema:
|
| 656 |
+
ema.eval() # EMA model should always be in eval mode
|
| 657 |
+
for epoch in range(args.start_epoch, args.epochs):
|
| 658 |
+
if args.distributed:
|
| 659 |
+
sampler_train.set_epoch(epoch)
|
| 660 |
+
train_stats = train_one_epoch(
|
| 661 |
+
model,
|
| 662 |
+
criterion,
|
| 663 |
+
data_loader_train,
|
| 664 |
+
optimizer,
|
| 665 |
+
device,
|
| 666 |
+
epoch,
|
| 667 |
+
args.clip_max_norm,
|
| 668 |
+
ema if args.ema else None,
|
| 669 |
+
ema_decay=args.ema_decay,
|
| 670 |
+
)
|
| 671 |
+
lr_scheduler.step()
|
| 672 |
+
if args.output_dir:
|
| 673 |
+
checkpoint_paths = [output_dir / "checkpoint.pth"]
|
| 674 |
+
# extra checkpoint before LR drop and every 5 epochs
|
| 675 |
+
if (
|
| 676 |
+
(epoch + 1) % args.lr_drop == 0
|
| 677 |
+
or (epoch + 1) % args.save_per_epochs == 0
|
| 678 |
+
or epoch + 1 == args.epochs
|
| 679 |
+
):
|
| 680 |
+
checkpoint_paths.append(output_dir / f"checkpoint{epoch:04}.pth")
|
| 681 |
+
for checkpoint_path in checkpoint_paths:
|
| 682 |
+
ckpt_dict = {
|
| 683 |
+
"model": model_without_ddp.state_dict(),
|
| 684 |
+
"optimizer": optimizer.state_dict(),
|
| 685 |
+
"lr_scheduler": lr_scheduler.state_dict(),
|
| 686 |
+
"epoch": epoch,
|
| 687 |
+
"args": args,
|
| 688 |
+
}
|
| 689 |
+
if args.ema:
|
| 690 |
+
ckpt_dict["ema"] = ema.state_dict()
|
| 691 |
+
utils.save_on_master(
|
| 692 |
+
ckpt_dict,
|
| 693 |
+
checkpoint_path,
|
| 694 |
+
)
|
| 695 |
+
|
| 696 |
+
torch.cuda.empty_cache()
|
| 697 |
+
if epoch % args.eval_per_epochs == 0 or epoch + 1 == args.epochs:
|
| 698 |
+
test_stats, coco_evaluator = evaluate_fn(
|
| 699 |
+
model,
|
| 700 |
+
criterion,
|
| 701 |
+
postprocessors,
|
| 702 |
+
data_loader_val,
|
| 703 |
+
base_ds,
|
| 704 |
+
device,
|
| 705 |
+
args.output_dir,
|
| 706 |
+
args.test_hflip_aug,
|
| 707 |
+
args.tta,
|
| 708 |
+
args.soft_nms,
|
| 709 |
+
ema if args.ema else None,
|
| 710 |
+
args.save_result,
|
| 711 |
+
args.save_result_dir,
|
| 712 |
+
soft_nms_method=args.soft_nms_method,
|
| 713 |
+
nms_thresh=args.nms_thresh,
|
| 714 |
+
quad_scale=args.quad_scale,
|
| 715 |
+
lsj_img_size=args.lsj_img_size,
|
| 716 |
+
)
|
| 717 |
+
log_stats = {
|
| 718 |
+
**{f"train_{k}": v for k, v in train_stats.items()},
|
| 719 |
+
**{f"test_{k}": v for k, v in test_stats.items()},
|
| 720 |
+
"epoch": epoch,
|
| 721 |
+
"n_parameters": n_parameters,
|
| 722 |
+
}
|
| 723 |
+
|
| 724 |
+
if args.output_dir and utils.is_main_process():
|
| 725 |
+
with (output_dir / "log.txt").open("a") as f:
|
| 726 |
+
f.write(json.dumps(log_stats) + "\n")
|
| 727 |
+
|
| 728 |
+
# for evaluation logs
|
| 729 |
+
if coco_evaluator is not None:
|
| 730 |
+
(output_dir / "eval").mkdir(exist_ok=True)
|
| 731 |
+
if "bbox" in coco_evaluator.coco_eval:
|
| 732 |
+
filenames = ["latest.pth"]
|
| 733 |
+
if epoch % 50 == 0:
|
| 734 |
+
filenames.append(f"{epoch:03}.pth")
|
| 735 |
+
for name in filenames:
|
| 736 |
+
torch.save(
|
| 737 |
+
coco_evaluator.coco_eval["bbox"].eval,
|
| 738 |
+
output_dir / "eval" / name,
|
| 739 |
+
)
|
| 740 |
+
torch.cuda.empty_cache()
|
| 741 |
+
|
| 742 |
+
total_time = time.time() - start_time
|
| 743 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
| 744 |
+
print("Training time {}".format(total_time_str))
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
if __name__ == "__main__":
|
| 748 |
+
parser = argparse.ArgumentParser(
|
| 749 |
+
"Deformable DETR training and evaluation script", parents=[get_args_parser()]
|
| 750 |
+
)
|
| 751 |
+
args = parser.parse_args()
|
| 752 |
+
if args.output_dir:
|
| 753 |
+
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
| 754 |
+
main(args)
|
perception_models/apps/detection/DETA_pe/models/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
# Modified from DETR (https://github.com/facebookresearch/detr)
|
| 7 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 8 |
+
# ------------------------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
from .deformable_detr import build
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def build_model(args):
|
| 14 |
+
return build(args)
|
| 15 |
+
|
perception_models/apps/detection/DETA_pe/models/assigner.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# Modified by Jeffrey Ouyang-Zhang
|
| 3 |
+
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
from util.box_ops import (
|
| 10 |
+
box_cxcywh_to_xyxy,
|
| 11 |
+
box_iou,
|
| 12 |
+
box_xyxy_to_cxcywh,
|
| 13 |
+
generalized_box_iou,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/layers/wrappers.py#L100
|
| 18 |
+
def nonzero_tuple(x):
|
| 19 |
+
"""
|
| 20 |
+
A 'as_tuple=True' version of torch.nonzero to support torchscript.
|
| 21 |
+
because of https://github.com/pytorch/pytorch/issues/38718
|
| 22 |
+
"""
|
| 23 |
+
if torch.jit.is_scripting():
|
| 24 |
+
if x.dim() == 0:
|
| 25 |
+
return x.unsqueeze(0).nonzero().unbind(1)
|
| 26 |
+
return x.nonzero().unbind(1)
|
| 27 |
+
else:
|
| 28 |
+
return x.nonzero(as_tuple=True)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# from https://github.com/facebookresearch/detectron2/blob/9921a2caa585d4fa66c4b534b6fab6e74d89b582/detectron2/modeling/matcher.py#L9
|
| 32 |
+
class Matcher(object):
|
| 33 |
+
"""
|
| 34 |
+
This class assigns to each predicted "element" (e.g., a box) a ground-truth
|
| 35 |
+
element. Each predicted element will have exactly zero or one matches; each
|
| 36 |
+
ground-truth element may be matched to zero or more predicted elements.
|
| 37 |
+
|
| 38 |
+
The matching is determined by the MxN match_quality_matrix, that characterizes
|
| 39 |
+
how well each (ground-truth, prediction)-pair match each other. For example,
|
| 40 |
+
if the elements are boxes, this matrix may contain box intersection-over-union
|
| 41 |
+
overlap values.
|
| 42 |
+
|
| 43 |
+
The matcher returns (a) a vector of length N containing the index of the
|
| 44 |
+
ground-truth element m in [0, M) that matches to prediction n in [0, N).
|
| 45 |
+
(b) a vector of length N containing the labels for each prediction.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
def __init__(
|
| 49 |
+
self,
|
| 50 |
+
thresholds: List[float],
|
| 51 |
+
labels: List[int],
|
| 52 |
+
allow_low_quality_matches: bool = False,
|
| 53 |
+
):
|
| 54 |
+
"""
|
| 55 |
+
Args:
|
| 56 |
+
thresholds (list): a list of thresholds used to stratify predictions
|
| 57 |
+
into levels.
|
| 58 |
+
labels (list): a list of values to label predictions belonging at
|
| 59 |
+
each level. A label can be one of {-1, 0, 1} signifying
|
| 60 |
+
{ignore, negative class, positive class}, respectively.
|
| 61 |
+
allow_low_quality_matches (bool): if True, produce additional matches
|
| 62 |
+
for predictions with maximum match quality lower than high_threshold.
|
| 63 |
+
See set_low_quality_matches_ for more details.
|
| 64 |
+
|
| 65 |
+
For example,
|
| 66 |
+
thresholds = [0.3, 0.5]
|
| 67 |
+
labels = [0, -1, 1]
|
| 68 |
+
All predictions with iou < 0.3 will be marked with 0 and
|
| 69 |
+
thus will be considered as false positives while training.
|
| 70 |
+
All predictions with 0.3 <= iou < 0.5 will be marked with -1 and
|
| 71 |
+
thus will be ignored.
|
| 72 |
+
All predictions with 0.5 <= iou will be marked with 1 and
|
| 73 |
+
thus will be considered as true positives.
|
| 74 |
+
"""
|
| 75 |
+
# Add -inf and +inf to first and last position in thresholds
|
| 76 |
+
thresholds = thresholds[:]
|
| 77 |
+
assert thresholds[0] > 0
|
| 78 |
+
thresholds.insert(0, -float("inf"))
|
| 79 |
+
thresholds.append(float("inf"))
|
| 80 |
+
# Currently torchscript does not support all + generator
|
| 81 |
+
assert all(
|
| 82 |
+
[low <= high for (low, high) in zip(thresholds[:-1], thresholds[1:])]
|
| 83 |
+
), thresholds
|
| 84 |
+
assert all([l in [-1, 0, 1] for l in labels])
|
| 85 |
+
assert len(labels) == len(thresholds) - 1
|
| 86 |
+
self.thresholds = thresholds
|
| 87 |
+
self.labels = labels
|
| 88 |
+
self.allow_low_quality_matches = allow_low_quality_matches
|
| 89 |
+
|
| 90 |
+
def __call__(self, match_quality_matrix):
|
| 91 |
+
"""
|
| 92 |
+
Args:
|
| 93 |
+
match_quality_matrix (Tensor[float]): an MxN tensor, containing the
|
| 94 |
+
pairwise quality between M ground-truth elements and N predicted
|
| 95 |
+
elements. All elements must be >= 0 (due to the us of `torch.nonzero`
|
| 96 |
+
for selecting indices in :meth:`set_low_quality_matches_`).
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
matches (Tensor[int64]): a vector of length N, where matches[i] is a matched
|
| 100 |
+
ground-truth index in [0, M)
|
| 101 |
+
match_labels (Tensor[int8]): a vector of length N, where pred_labels[i] indicates
|
| 102 |
+
whether a prediction is a true or false positive or ignored
|
| 103 |
+
"""
|
| 104 |
+
assert match_quality_matrix.dim() == 2
|
| 105 |
+
if match_quality_matrix.numel() == 0:
|
| 106 |
+
default_matches = match_quality_matrix.new_full(
|
| 107 |
+
(match_quality_matrix.size(1),), 0, dtype=torch.int64
|
| 108 |
+
)
|
| 109 |
+
# When no gt boxes exist, we define IOU = 0 and therefore set labels
|
| 110 |
+
# to `self.labels[0]`, which usually defaults to background class 0
|
| 111 |
+
# To choose to ignore instead, can make labels=[-1,0,-1,1] + set appropriate thresholds
|
| 112 |
+
default_match_labels = match_quality_matrix.new_full(
|
| 113 |
+
(match_quality_matrix.size(1),), self.labels[0], dtype=torch.int8
|
| 114 |
+
)
|
| 115 |
+
return default_matches, default_match_labels
|
| 116 |
+
|
| 117 |
+
assert torch.all(match_quality_matrix >= 0)
|
| 118 |
+
|
| 119 |
+
# match_quality_matrix is M (gt) x N (predicted)
|
| 120 |
+
# Max over gt elements (dim 0) to find best gt candidate for each prediction
|
| 121 |
+
matched_vals, matches = match_quality_matrix.max(dim=0)
|
| 122 |
+
|
| 123 |
+
match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)
|
| 124 |
+
|
| 125 |
+
for l, low, high in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
|
| 126 |
+
low_high = (matched_vals >= low) & (matched_vals < high)
|
| 127 |
+
match_labels[low_high] = l
|
| 128 |
+
|
| 129 |
+
if self.allow_low_quality_matches:
|
| 130 |
+
self.set_low_quality_matches_(match_labels, match_quality_matrix)
|
| 131 |
+
|
| 132 |
+
return matches, match_labels
|
| 133 |
+
|
| 134 |
+
def set_low_quality_matches_(self, match_labels, match_quality_matrix):
|
| 135 |
+
"""
|
| 136 |
+
Produce additional matches for predictions that have only low-quality matches.
|
| 137 |
+
Specifically, for each ground-truth G find the set of predictions that have
|
| 138 |
+
maximum overlap with it (including ties); for each prediction in that set, if
|
| 139 |
+
it is unmatched, then match it to the ground-truth G.
|
| 140 |
+
|
| 141 |
+
This function implements the RPN assignment case (i) in Sec. 3.1.2 of
|
| 142 |
+
:paper:`Faster R-CNN`.
|
| 143 |
+
"""
|
| 144 |
+
# For each gt, find the prediction with which it has highest quality
|
| 145 |
+
highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
|
| 146 |
+
# Find the highest quality match available, even if it is low, including ties.
|
| 147 |
+
# Note that the matches qualities must be positive due to the use of
|
| 148 |
+
# `torch.nonzero`.
|
| 149 |
+
_, pred_inds_with_highest_quality = nonzero_tuple(
|
| 150 |
+
match_quality_matrix == highest_quality_foreach_gt[:, None]
|
| 151 |
+
)
|
| 152 |
+
# If an anchor was labeled positive only due to a low-quality match
|
| 153 |
+
# with gt_A, but it has larger overlap with gt_B, it's matched index will still be gt_B.
|
| 154 |
+
# This follows the implementation in Detectron, and is found to have no significant impact.
|
| 155 |
+
match_labels[pred_inds_with_highest_quality] = 1
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
# from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/modeling/sampling.py#L9
|
| 159 |
+
def subsample_labels(
|
| 160 |
+
labels: torch.Tensor, num_samples: int, positive_fraction: float, bg_label: int
|
| 161 |
+
):
|
| 162 |
+
"""
|
| 163 |
+
Return `num_samples` (or fewer, if not enough found)
|
| 164 |
+
random samples from `labels` which is a mixture of positives & negatives.
|
| 165 |
+
It will try to return as many positives as possible without
|
| 166 |
+
exceeding `positive_fraction * num_samples`, and then try to
|
| 167 |
+
fill the remaining slots with negatives.
|
| 168 |
+
|
| 169 |
+
Args:
|
| 170 |
+
labels (Tensor): (N, ) label vector with values:
|
| 171 |
+
* -1: ignore
|
| 172 |
+
* bg_label: background ("negative") class
|
| 173 |
+
* otherwise: one or more foreground ("positive") classes
|
| 174 |
+
num_samples (int): The total number of labels with value >= 0 to return.
|
| 175 |
+
Values that are not sampled will be filled with -1 (ignore).
|
| 176 |
+
positive_fraction (float): The number of subsampled labels with values > 0
|
| 177 |
+
is `min(num_positives, int(positive_fraction * num_samples))`. The number
|
| 178 |
+
of negatives sampled is `min(num_negatives, num_samples - num_positives_sampled)`.
|
| 179 |
+
In order words, if there are not enough positives, the sample is filled with
|
| 180 |
+
negatives. If there are also not enough negatives, then as many elements are
|
| 181 |
+
sampled as is possible.
|
| 182 |
+
bg_label (int): label index of background ("negative") class.
|
| 183 |
+
|
| 184 |
+
Returns:
|
| 185 |
+
pos_idx, neg_idx (Tensor):
|
| 186 |
+
1D vector of indices. The total length of both is `num_samples` or fewer.
|
| 187 |
+
"""
|
| 188 |
+
positive = nonzero_tuple((labels != -1) & (labels != bg_label))[0]
|
| 189 |
+
negative = nonzero_tuple(labels == bg_label)[0]
|
| 190 |
+
|
| 191 |
+
num_pos = int(num_samples * positive_fraction)
|
| 192 |
+
# protect against not enough positive examples
|
| 193 |
+
num_pos = min(positive.numel(), num_pos)
|
| 194 |
+
num_neg = num_samples - num_pos
|
| 195 |
+
# protect against not enough negative examples
|
| 196 |
+
num_neg = min(negative.numel(), num_neg)
|
| 197 |
+
|
| 198 |
+
# randomly select positive and negative examples
|
| 199 |
+
perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
|
| 200 |
+
perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]
|
| 201 |
+
|
| 202 |
+
pos_idx = positive[perm1]
|
| 203 |
+
neg_idx = negative[perm2]
|
| 204 |
+
return pos_idx, neg_idx
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def sample_topk_per_gt(pr_inds, gt_inds, iou, k):
|
| 208 |
+
if len(gt_inds) == 0:
|
| 209 |
+
return pr_inds, gt_inds
|
| 210 |
+
# find topk matches for each gt
|
| 211 |
+
gt_inds2, counts = gt_inds.unique(return_counts=True)
|
| 212 |
+
scores, pr_inds2 = iou[gt_inds2].topk(k, dim=1)
|
| 213 |
+
gt_inds2 = gt_inds2[:, None].repeat(1, k)
|
| 214 |
+
|
| 215 |
+
# filter to as many matches that gt has
|
| 216 |
+
pr_inds3 = torch.cat([pr[:c] for c, pr in zip(counts, pr_inds2)])
|
| 217 |
+
gt_inds3 = torch.cat([gt[:c] for c, gt in zip(counts, gt_inds2)])
|
| 218 |
+
return pr_inds3, gt_inds3
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
# modified from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/modeling/roi_heads/roi_heads.py#L123
|
| 222 |
+
class Stage2Assigner(nn.Module):
|
| 223 |
+
def __init__(self, num_queries, max_k=4):
|
| 224 |
+
super().__init__()
|
| 225 |
+
self.positive_fraction = 0.25
|
| 226 |
+
self.bg_label = 400 # number > 91 to filter out later
|
| 227 |
+
self.batch_size_per_image = num_queries
|
| 228 |
+
self.proposal_matcher = Matcher(
|
| 229 |
+
thresholds=[0.6], labels=[0, 1], allow_low_quality_matches=True
|
| 230 |
+
)
|
| 231 |
+
self.k = max_k
|
| 232 |
+
|
| 233 |
+
def _sample_proposals(
|
| 234 |
+
self,
|
| 235 |
+
matched_idxs: torch.Tensor,
|
| 236 |
+
matched_labels: torch.Tensor,
|
| 237 |
+
gt_classes: torch.Tensor,
|
| 238 |
+
):
|
| 239 |
+
"""
|
| 240 |
+
Based on the matching between N proposals and M groundtruth,
|
| 241 |
+
sample the proposals and set their classification labels.
|
| 242 |
+
|
| 243 |
+
Args:
|
| 244 |
+
matched_idxs (Tensor): a vector of length N, each is the best-matched
|
| 245 |
+
gt index in [0, M) for each proposal.
|
| 246 |
+
matched_labels (Tensor): a vector of length N, the matcher's label
|
| 247 |
+
(one of cfg.MODEL.ROI_HEADS.IOU_LABELS) for each proposal.
|
| 248 |
+
gt_classes (Tensor): a vector of length M.
|
| 249 |
+
|
| 250 |
+
Returns:
|
| 251 |
+
Tensor: a vector of indices of sampled proposals. Each is in [0, N).
|
| 252 |
+
Tensor: a vector of the same length, the classification label for
|
| 253 |
+
each sampled proposal. Each sample is labeled as either a category in
|
| 254 |
+
[0, num_classes) or the background (num_classes).
|
| 255 |
+
"""
|
| 256 |
+
has_gt = gt_classes.numel() > 0
|
| 257 |
+
# Get the corresponding GT for each proposal
|
| 258 |
+
if has_gt:
|
| 259 |
+
gt_classes = gt_classes[matched_idxs]
|
| 260 |
+
# Label unmatched proposals (0 label from matcher) as background (label=num_classes)
|
| 261 |
+
gt_classes[matched_labels == 0] = self.bg_label
|
| 262 |
+
# Label ignore proposals (-1 label)
|
| 263 |
+
gt_classes[matched_labels == -1] = -1
|
| 264 |
+
else:
|
| 265 |
+
gt_classes = torch.zeros_like(matched_idxs) + self.bg_label
|
| 266 |
+
|
| 267 |
+
sampled_fg_idxs, sampled_bg_idxs = subsample_labels(
|
| 268 |
+
gt_classes, self.batch_size_per_image, self.positive_fraction, self.bg_label
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
sampled_idxs = torch.cat([sampled_fg_idxs, sampled_bg_idxs], dim=0)
|
| 272 |
+
return sampled_idxs, gt_classes[sampled_idxs]
|
| 273 |
+
|
| 274 |
+
def forward(self, outputs, targets, return_cost_matrix=False):
|
| 275 |
+
# COCO categories are from 1 to 90. They set num_classes=91 and apply sigmoid.
|
| 276 |
+
|
| 277 |
+
bs = len(targets)
|
| 278 |
+
indices = []
|
| 279 |
+
ious = []
|
| 280 |
+
for b in range(bs):
|
| 281 |
+
iou, _ = box_iou(
|
| 282 |
+
box_cxcywh_to_xyxy(targets[b]["boxes"]),
|
| 283 |
+
box_cxcywh_to_xyxy(outputs["init_reference"][b].detach()),
|
| 284 |
+
)
|
| 285 |
+
matched_idxs, matched_labels = self.proposal_matcher(
|
| 286 |
+
iou
|
| 287 |
+
) # proposal_id -> highest_iou_gt_id, proposal_id -> [1 if iou > 0.6, 0 ow]
|
| 288 |
+
sampled_idxs, sampled_gt_classes = (
|
| 289 |
+
self._sample_proposals( # list of sampled proposal_ids, sampled_id -> [0, num_classes)+[bg_label]
|
| 290 |
+
matched_idxs, matched_labels, targets[b]["labels"]
|
| 291 |
+
)
|
| 292 |
+
)
|
| 293 |
+
pos_pr_inds = sampled_idxs[sampled_gt_classes != self.bg_label]
|
| 294 |
+
pos_gt_inds = matched_idxs[pos_pr_inds]
|
| 295 |
+
pos_pr_inds, pos_gt_inds = self.postprocess_indices(
|
| 296 |
+
pos_pr_inds, pos_gt_inds, iou
|
| 297 |
+
)
|
| 298 |
+
indices.append((pos_pr_inds, pos_gt_inds))
|
| 299 |
+
ious.append(iou)
|
| 300 |
+
if return_cost_matrix:
|
| 301 |
+
return indices, ious
|
| 302 |
+
return indices
|
| 303 |
+
|
| 304 |
+
def postprocess_indices(self, pr_inds, gt_inds, iou):
|
| 305 |
+
return sample_topk_per_gt(pr_inds, gt_inds, iou, self.k)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
# modified from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/modeling/proposal_generator/rpn.py#L181
|
| 309 |
+
class Stage1Assigner(nn.Module):
|
| 310 |
+
def __init__(self, t_low=0.3, t_high=0.7, max_k=4):
|
| 311 |
+
super().__init__()
|
| 312 |
+
self.positive_fraction = 0.5
|
| 313 |
+
self.batch_size_per_image = 256
|
| 314 |
+
self.k = max_k
|
| 315 |
+
self.t_low = t_low
|
| 316 |
+
self.t_high = t_high
|
| 317 |
+
self.anchor_matcher = Matcher(
|
| 318 |
+
thresholds=[t_low, t_high],
|
| 319 |
+
labels=[0, -1, 1],
|
| 320 |
+
allow_low_quality_matches=True,
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
def _subsample_labels(self, label):
|
| 324 |
+
"""
|
| 325 |
+
Randomly sample a subset of positive and negative examples, and overwrite
|
| 326 |
+
the label vector to the ignore value (-1) for all elements that are not
|
| 327 |
+
included in the sample.
|
| 328 |
+
|
| 329 |
+
Args:
|
| 330 |
+
labels (Tensor): a vector of -1, 0, 1. Will be modified in-place and returned.
|
| 331 |
+
"""
|
| 332 |
+
pos_idx, neg_idx = subsample_labels(
|
| 333 |
+
label, self.batch_size_per_image, self.positive_fraction, 0
|
| 334 |
+
)
|
| 335 |
+
# Fill with the ignore label (-1), then set positive and negative labels
|
| 336 |
+
label.fill_(-1)
|
| 337 |
+
label.scatter_(0, pos_idx, 1)
|
| 338 |
+
label.scatter_(0, neg_idx, 0)
|
| 339 |
+
return label
|
| 340 |
+
|
| 341 |
+
def forward(self, outputs, targets):
|
| 342 |
+
bs = len(targets)
|
| 343 |
+
indices = []
|
| 344 |
+
for b in range(bs):
|
| 345 |
+
anchors = outputs["anchors"][b]
|
| 346 |
+
if len(targets[b]["boxes"]) == 0:
|
| 347 |
+
indices.append(
|
| 348 |
+
(
|
| 349 |
+
torch.tensor([], dtype=torch.long, device=anchors.device),
|
| 350 |
+
torch.tensor([], dtype=torch.long, device=anchors.device),
|
| 351 |
+
)
|
| 352 |
+
)
|
| 353 |
+
continue
|
| 354 |
+
iou, _ = box_iou(
|
| 355 |
+
box_cxcywh_to_xyxy(targets[b]["boxes"]),
|
| 356 |
+
box_cxcywh_to_xyxy(anchors),
|
| 357 |
+
)
|
| 358 |
+
matched_idxs, matched_labels = self.anchor_matcher(
|
| 359 |
+
iou
|
| 360 |
+
) # proposal_id -> highest_iou_gt_id, proposal_id -> [1 if iou > 0.7, 0 if iou < 0.3, -1 ow]
|
| 361 |
+
matched_labels = self._subsample_labels(matched_labels)
|
| 362 |
+
|
| 363 |
+
all_pr_inds = torch.arange(len(anchors)).to(anchors.device)
|
| 364 |
+
|
| 365 |
+
pos_pr_inds = all_pr_inds[matched_labels == 1]
|
| 366 |
+
pos_gt_inds = matched_idxs[pos_pr_inds]
|
| 367 |
+
pos_ious = iou[pos_gt_inds, pos_pr_inds]
|
| 368 |
+
pos_pr_inds, pos_gt_inds = self.postprocess_indices(
|
| 369 |
+
pos_pr_inds, pos_gt_inds, iou
|
| 370 |
+
)
|
| 371 |
+
pos_pr_inds, pos_gt_inds = pos_pr_inds.to(anchors.device), pos_gt_inds.to(
|
| 372 |
+
anchors.device
|
| 373 |
+
)
|
| 374 |
+
indices.append((pos_pr_inds, pos_gt_inds))
|
| 375 |
+
return indices
|
| 376 |
+
|
| 377 |
+
def postprocess_indices(self, pr_inds, gt_inds, iou):
|
| 378 |
+
return sample_topk_per_gt(pr_inds, gt_inds, iou, self.k)
|
perception_models/apps/detection/DETA_pe/models/backbone.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
# Modified from DETR (https://github.com/facebookresearch/detr)
|
| 7 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 8 |
+
# ------------------------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
Backbone modules.
|
| 12 |
+
"""
|
| 13 |
+
from collections import OrderedDict
|
| 14 |
+
from functools import partial
|
| 15 |
+
from typing import Dict, List
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
import torchvision
|
| 20 |
+
from torch import nn
|
| 21 |
+
from torch.cuda.amp import autocast
|
| 22 |
+
from torchvision.models._utils import IntermediateLayerGetter
|
| 23 |
+
from util.misc import is_main_process, NestedTensor
|
| 24 |
+
|
| 25 |
+
from .position_encoding import build_position_encoding
|
| 26 |
+
from .swin import get_swinl
|
| 27 |
+
from .pev1 import get_pev1_and_fpn_backbone
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class FrozenBatchNorm2d(torch.nn.Module):
|
| 31 |
+
"""
|
| 32 |
+
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
| 33 |
+
|
| 34 |
+
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
|
| 35 |
+
without which any other models than torchvision.models.resnet[18,34,50,101]
|
| 36 |
+
produce nans.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
def __init__(self, n, eps=1e-5):
|
| 40 |
+
super(FrozenBatchNorm2d, self).__init__()
|
| 41 |
+
self.register_buffer("weight", torch.ones(n))
|
| 42 |
+
self.register_buffer("bias", torch.zeros(n))
|
| 43 |
+
self.register_buffer("running_mean", torch.zeros(n))
|
| 44 |
+
self.register_buffer("running_var", torch.ones(n))
|
| 45 |
+
self.eps = eps
|
| 46 |
+
|
| 47 |
+
def _load_from_state_dict(
|
| 48 |
+
self,
|
| 49 |
+
state_dict,
|
| 50 |
+
prefix,
|
| 51 |
+
local_metadata,
|
| 52 |
+
strict,
|
| 53 |
+
missing_keys,
|
| 54 |
+
unexpected_keys,
|
| 55 |
+
error_msgs,
|
| 56 |
+
):
|
| 57 |
+
num_batches_tracked_key = prefix + "num_batches_tracked"
|
| 58 |
+
if num_batches_tracked_key in state_dict:
|
| 59 |
+
del state_dict[num_batches_tracked_key]
|
| 60 |
+
|
| 61 |
+
super(FrozenBatchNorm2d, self)._load_from_state_dict(
|
| 62 |
+
state_dict,
|
| 63 |
+
prefix,
|
| 64 |
+
local_metadata,
|
| 65 |
+
strict,
|
| 66 |
+
missing_keys,
|
| 67 |
+
unexpected_keys,
|
| 68 |
+
error_msgs,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
def forward(self, x):
|
| 72 |
+
# move reshapes to the beginning
|
| 73 |
+
# to make it fuser-friendly
|
| 74 |
+
w = self.weight.reshape(1, -1, 1, 1)
|
| 75 |
+
b = self.bias.reshape(1, -1, 1, 1)
|
| 76 |
+
rv = self.running_var.reshape(1, -1, 1, 1)
|
| 77 |
+
rm = self.running_mean.reshape(1, -1, 1, 1)
|
| 78 |
+
eps = self.eps
|
| 79 |
+
scale = w * (rv + eps).rsqrt()
|
| 80 |
+
bias = b - rm * scale
|
| 81 |
+
return x * scale + bias
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class BackboneBase(nn.Module):
|
| 85 |
+
|
| 86 |
+
def __init__(
|
| 87 |
+
self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool
|
| 88 |
+
):
|
| 89 |
+
super().__init__()
|
| 90 |
+
for name, parameter in backbone.named_parameters():
|
| 91 |
+
if (
|
| 92 |
+
not train_backbone
|
| 93 |
+
or "layer2" not in name
|
| 94 |
+
and "layer3" not in name
|
| 95 |
+
and "layer4" not in name
|
| 96 |
+
):
|
| 97 |
+
parameter.requires_grad_(False)
|
| 98 |
+
if return_interm_layers:
|
| 99 |
+
# return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
|
| 100 |
+
return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
|
| 101 |
+
self.strides = [8, 16, 32]
|
| 102 |
+
self.num_channels = [512, 1024, 2048]
|
| 103 |
+
else:
|
| 104 |
+
return_layers = {"layer4": "0"}
|
| 105 |
+
self.strides = [32]
|
| 106 |
+
self.num_channels = [2048]
|
| 107 |
+
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
|
| 108 |
+
|
| 109 |
+
def forward(self, tensor_list: NestedTensor):
|
| 110 |
+
xs = self.body(tensor_list.tensors)
|
| 111 |
+
out: Dict[str, NestedTensor] = {}
|
| 112 |
+
for name, x in xs.items():
|
| 113 |
+
m = tensor_list.mask
|
| 114 |
+
assert m is not None
|
| 115 |
+
mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
|
| 116 |
+
out[name] = NestedTensor(x, mask)
|
| 117 |
+
return out
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class Backbone(BackboneBase):
|
| 121 |
+
"""ResNet backbone with frozen BatchNorm."""
|
| 122 |
+
|
| 123 |
+
def __init__(
|
| 124 |
+
self,
|
| 125 |
+
name: str,
|
| 126 |
+
train_backbone: bool,
|
| 127 |
+
return_interm_layers: bool,
|
| 128 |
+
dilation: bool,
|
| 129 |
+
):
|
| 130 |
+
norm_layer = FrozenBatchNorm2d
|
| 131 |
+
backbone = getattr(torchvision.models, name)(
|
| 132 |
+
replace_stride_with_dilation=[False, False, dilation],
|
| 133 |
+
pretrained=is_main_process(),
|
| 134 |
+
norm_layer=norm_layer,
|
| 135 |
+
)
|
| 136 |
+
assert name not in ("resnet18", "resnet34"), "number of channels are hard coded"
|
| 137 |
+
super().__init__(backbone, train_backbone, return_interm_layers)
|
| 138 |
+
if dilation:
|
| 139 |
+
self.strides[-1] = self.strides[-1] // 2
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class SwinBackbone(nn.Module):
|
| 143 |
+
def __init__(self):
|
| 144 |
+
# we skip R50 FrozenBatchNorm2d, dilation, train l{2,3,4} only
|
| 145 |
+
super().__init__()
|
| 146 |
+
self.body = get_swinl()
|
| 147 |
+
self.features = ["res3", "res4", "res5"]
|
| 148 |
+
self.strides = [8, 16, 32]
|
| 149 |
+
self.num_channels = [384, 768, 1536]
|
| 150 |
+
|
| 151 |
+
def forward(self, tensor_list: NestedTensor):
|
| 152 |
+
xs = self.body(tensor_list.tensors)
|
| 153 |
+
m = tensor_list.mask[None]
|
| 154 |
+
assert m is not None
|
| 155 |
+
out: Dict[str, NestedTensor] = {}
|
| 156 |
+
for name in self.features:
|
| 157 |
+
mask = F.interpolate(m.float(), size=xs[name].shape[-2:]).to(torch.bool)[0]
|
| 158 |
+
out[name] = NestedTensor(xs[name], mask)
|
| 159 |
+
return out
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class PEv1Backbone(nn.Module):
|
| 163 |
+
def __init__(self, args):
|
| 164 |
+
super().__init__()
|
| 165 |
+
self.body = get_pev1_and_fpn_backbone(args)
|
| 166 |
+
self.features = self.body._out_features
|
| 167 |
+
|
| 168 |
+
self.bf16 = args.bf16
|
| 169 |
+
self.fp16 = args.fp16
|
| 170 |
+
|
| 171 |
+
_out_feature_strides = self.body._out_feature_strides
|
| 172 |
+
_out_feature_channels = self.body._out_feature_channels
|
| 173 |
+
self.strides = [_out_feature_strides[f] for f in _out_feature_strides.keys()]
|
| 174 |
+
self.num_channels = [
|
| 175 |
+
_out_feature_channels[f] for f in _out_feature_channels.keys()
|
| 176 |
+
]
|
| 177 |
+
|
| 178 |
+
def forward(self, tensor_list: NestedTensor):
|
| 179 |
+
# xs = self.body(tensor_list.tensors)
|
| 180 |
+
# backbone
|
| 181 |
+
if self.bf16:
|
| 182 |
+
with autocast(dtype=torch.bfloat16):
|
| 183 |
+
xs = self.body(tensor_list.tensors.to(torch.bfloat16))
|
| 184 |
+
xs = {k: v.float() for k, v in xs.items()}
|
| 185 |
+
elif self.fp16:
|
| 186 |
+
with autocast(dtype=torch.float16):
|
| 187 |
+
xs = self.body(tensor_list.tensors.half())
|
| 188 |
+
xs = {k: v.float() for k, v in xs.items()}
|
| 189 |
+
else:
|
| 190 |
+
xs = self.body(tensor_list.tensors)
|
| 191 |
+
|
| 192 |
+
m = tensor_list.mask[None]
|
| 193 |
+
assert m is not None
|
| 194 |
+
out: Dict[str, NestedTensor] = {}
|
| 195 |
+
|
| 196 |
+
for name in self.features:
|
| 197 |
+
mask = F.interpolate(m.float(), size=xs[name].shape[-2:]).to(torch.bool)[0]
|
| 198 |
+
out[name] = NestedTensor(xs[name], mask)
|
| 199 |
+
return out
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class Joiner(nn.Sequential):
|
| 203 |
+
def __init__(self, backbone, position_embedding):
|
| 204 |
+
super().__init__(backbone, position_embedding)
|
| 205 |
+
self.strides = backbone.strides
|
| 206 |
+
self.num_channels = backbone.num_channels
|
| 207 |
+
|
| 208 |
+
def forward(self, tensor_list: NestedTensor):
|
| 209 |
+
xs = self[0](tensor_list)
|
| 210 |
+
out: List[NestedTensor] = []
|
| 211 |
+
pos = []
|
| 212 |
+
for name, x in sorted(xs.items()):
|
| 213 |
+
out.append(x)
|
| 214 |
+
|
| 215 |
+
# position encoding
|
| 216 |
+
for x in out:
|
| 217 |
+
pos.append(self[1](x).to(x.tensors.dtype))
|
| 218 |
+
|
| 219 |
+
return out, pos
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def build_backbone(args):
|
| 223 |
+
position_embedding = build_position_encoding(args)
|
| 224 |
+
train_backbone = args.lr_backbone > 0
|
| 225 |
+
return_interm_layers = args.masks or (args.num_feature_levels > 1)
|
| 226 |
+
if "swin" in args.backbone:
|
| 227 |
+
backbone = SwinBackbone()
|
| 228 |
+
elif "pev1" in args.backbone:
|
| 229 |
+
backbone = PEv1Backbone(args)
|
| 230 |
+
else:
|
| 231 |
+
backbone = Backbone(
|
| 232 |
+
args.backbone, train_backbone, return_interm_layers, args.dilation
|
| 233 |
+
)
|
| 234 |
+
model = Joiner(backbone, position_embedding)
|
| 235 |
+
return model
|
perception_models/apps/detection/DETA_pe/models/deformable_detr.py
ADDED
|
@@ -0,0 +1,776 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
# Modified from DETR (https://github.com/facebookresearch/detr)
|
| 7 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 8 |
+
# ------------------------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
Deformable DETR model and criterion classes.
|
| 12 |
+
"""
|
| 13 |
+
import copy
|
| 14 |
+
import math
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
from torch import nn
|
| 19 |
+
from torchvision.ops.boxes import batched_nms
|
| 20 |
+
|
| 21 |
+
from util import box_ops
|
| 22 |
+
from util.misc import (
|
| 23 |
+
accuracy,
|
| 24 |
+
get_world_size,
|
| 25 |
+
interpolate,
|
| 26 |
+
inverse_sigmoid,
|
| 27 |
+
is_dist_avail_and_initialized,
|
| 28 |
+
nested_tensor_from_tensor_list,
|
| 29 |
+
NestedTensor,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
from .assigner import Stage1Assigner, Stage2Assigner
|
| 33 |
+
|
| 34 |
+
from .backbone import build_backbone
|
| 35 |
+
from .deformable_transformer import build_deforamble_transformer
|
| 36 |
+
from .matcher import build_matcher
|
| 37 |
+
from .segmentation import (
|
| 38 |
+
DETRsegm,
|
| 39 |
+
dice_loss,
|
| 40 |
+
PostProcessPanoptic,
|
| 41 |
+
PostProcessSegm,
|
| 42 |
+
sigmoid_focal_loss,
|
| 43 |
+
)
|
| 44 |
+
from .utils_fed_loss import get_fed_loss_inds, load_class_freq
|
| 45 |
+
from .utils_softnms import batched_soft_nms
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _get_clones(module, N):
|
| 49 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class DeformableDETR(nn.Module):
|
| 53 |
+
"""This is the Deformable DETR module that performs object detection"""
|
| 54 |
+
|
| 55 |
+
def __init__(
|
| 56 |
+
self,
|
| 57 |
+
backbone,
|
| 58 |
+
transformer,
|
| 59 |
+
num_classes,
|
| 60 |
+
num_queries,
|
| 61 |
+
num_feature_levels,
|
| 62 |
+
aux_loss=True,
|
| 63 |
+
with_box_refine=False,
|
| 64 |
+
two_stage=False,
|
| 65 |
+
):
|
| 66 |
+
"""Initializes the model.
|
| 67 |
+
Parameters:
|
| 68 |
+
backbone: torch module of the backbone to be used. See backbone.py
|
| 69 |
+
transformer: torch module of the transformer architecture. See transformer.py
|
| 70 |
+
num_classes: number of object classes
|
| 71 |
+
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
|
| 72 |
+
DETR can detect in a single image. For COCO, we recommend 100 queries.
|
| 73 |
+
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
|
| 74 |
+
with_box_refine: iterative bounding box refinement
|
| 75 |
+
two_stage: two-stage Deformable DETR
|
| 76 |
+
"""
|
| 77 |
+
super().__init__()
|
| 78 |
+
self.num_queries = num_queries
|
| 79 |
+
self.transformer = transformer
|
| 80 |
+
hidden_dim = transformer.d_model
|
| 81 |
+
self.class_embed = nn.Linear(hidden_dim, num_classes)
|
| 82 |
+
self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
|
| 83 |
+
self.num_feature_levels = num_feature_levels
|
| 84 |
+
if not two_stage:
|
| 85 |
+
self.query_embed = nn.Embedding(num_queries, hidden_dim * 2)
|
| 86 |
+
if num_feature_levels > 1:
|
| 87 |
+
num_backbone_outs = len(backbone.strides)
|
| 88 |
+
input_proj_list = []
|
| 89 |
+
for _ in range(num_backbone_outs):
|
| 90 |
+
in_channels = backbone.num_channels[_]
|
| 91 |
+
input_proj_list.append(
|
| 92 |
+
nn.Sequential(
|
| 93 |
+
nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
|
| 94 |
+
nn.GroupNorm(32, hidden_dim),
|
| 95 |
+
)
|
| 96 |
+
)
|
| 97 |
+
for _ in range(num_feature_levels - num_backbone_outs):
|
| 98 |
+
input_proj_list.append(
|
| 99 |
+
nn.Sequential(
|
| 100 |
+
nn.Conv2d(
|
| 101 |
+
in_channels, hidden_dim, kernel_size=3, stride=2, padding=1
|
| 102 |
+
),
|
| 103 |
+
nn.GroupNorm(32, hidden_dim),
|
| 104 |
+
)
|
| 105 |
+
)
|
| 106 |
+
in_channels = hidden_dim
|
| 107 |
+
self.input_proj = nn.ModuleList(input_proj_list)
|
| 108 |
+
else:
|
| 109 |
+
self.input_proj = nn.ModuleList(
|
| 110 |
+
[
|
| 111 |
+
nn.Sequential(
|
| 112 |
+
nn.Conv2d(backbone.num_channels[0], hidden_dim, kernel_size=1),
|
| 113 |
+
nn.GroupNorm(32, hidden_dim),
|
| 114 |
+
)
|
| 115 |
+
]
|
| 116 |
+
)
|
| 117 |
+
self.backbone = backbone
|
| 118 |
+
self.aux_loss = aux_loss
|
| 119 |
+
self.with_box_refine = with_box_refine
|
| 120 |
+
self.two_stage = two_stage
|
| 121 |
+
|
| 122 |
+
prior_prob = 0.01
|
| 123 |
+
bias_value = -math.log((1 - prior_prob) / prior_prob)
|
| 124 |
+
self.class_embed.bias.data = torch.ones(num_classes) * bias_value
|
| 125 |
+
nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
|
| 126 |
+
nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
|
| 127 |
+
for proj in self.input_proj:
|
| 128 |
+
nn.init.xavier_uniform_(proj[0].weight, gain=1)
|
| 129 |
+
nn.init.constant_(proj[0].bias, 0)
|
| 130 |
+
|
| 131 |
+
# if two-stage, the last class_embed and bbox_embed is for region proposal generation
|
| 132 |
+
num_pred = (
|
| 133 |
+
(transformer.decoder.num_layers + 1)
|
| 134 |
+
if two_stage
|
| 135 |
+
else transformer.decoder.num_layers
|
| 136 |
+
)
|
| 137 |
+
if with_box_refine:
|
| 138 |
+
self.class_embed = _get_clones(self.class_embed, num_pred)
|
| 139 |
+
self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
|
| 140 |
+
nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0)
|
| 141 |
+
# hack implementation for iterative bounding box refinement
|
| 142 |
+
self.transformer.decoder.bbox_embed = self.bbox_embed
|
| 143 |
+
else:
|
| 144 |
+
nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
|
| 145 |
+
self.class_embed = nn.ModuleList(
|
| 146 |
+
[self.class_embed for _ in range(num_pred)]
|
| 147 |
+
)
|
| 148 |
+
self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)])
|
| 149 |
+
self.transformer.decoder.bbox_embed = None
|
| 150 |
+
if two_stage:
|
| 151 |
+
# hack implementation for two-stage
|
| 152 |
+
self.transformer.decoder.class_embed = self.class_embed
|
| 153 |
+
for box_embed in self.bbox_embed:
|
| 154 |
+
nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0)
|
| 155 |
+
|
| 156 |
+
def forward(self, samples: NestedTensor):
|
| 157 |
+
"""The forward expects a NestedTensor, which consists of:
|
| 158 |
+
- samples.tensor: batched images, of shape [batch_size x 3 x H x W]
|
| 159 |
+
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
|
| 160 |
+
|
| 161 |
+
It returns a dict with the following elements:
|
| 162 |
+
- "pred_logits": the classification logits (including no-object) for all queries.
|
| 163 |
+
Shape= [batch_size x num_queries x (num_classes + 1)]
|
| 164 |
+
- "pred_boxes": The normalized boxes coordinates for all queries, represented as
|
| 165 |
+
(center_x, center_y, height, width). These values are normalized in [0, 1],
|
| 166 |
+
relative to the size of each individual image (disregarding possible padding).
|
| 167 |
+
See PostProcess for information on how to retrieve the unnormalized bounding box.
|
| 168 |
+
- "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
|
| 169 |
+
dictionnaries containing the two above keys for each decoder layer.
|
| 170 |
+
"""
|
| 171 |
+
if not isinstance(samples, NestedTensor):
|
| 172 |
+
samples = nested_tensor_from_tensor_list(samples)
|
| 173 |
+
features, pos = self.backbone(samples)
|
| 174 |
+
|
| 175 |
+
srcs = []
|
| 176 |
+
masks = []
|
| 177 |
+
for l, feat in enumerate(features):
|
| 178 |
+
src, mask = feat.decompose()
|
| 179 |
+
srcs.append(self.input_proj[l](src))
|
| 180 |
+
masks.append(mask)
|
| 181 |
+
assert mask is not None
|
| 182 |
+
if self.num_feature_levels > len(srcs):
|
| 183 |
+
_len_srcs = len(srcs)
|
| 184 |
+
for l in range(_len_srcs, self.num_feature_levels):
|
| 185 |
+
if l == _len_srcs:
|
| 186 |
+
src = self.input_proj[l](features[-1].tensors)
|
| 187 |
+
else:
|
| 188 |
+
src = self.input_proj[l](srcs[-1])
|
| 189 |
+
m = samples.mask
|
| 190 |
+
mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(
|
| 191 |
+
torch.bool
|
| 192 |
+
)[0]
|
| 193 |
+
pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
|
| 194 |
+
srcs.append(src)
|
| 195 |
+
masks.append(mask)
|
| 196 |
+
pos.append(pos_l)
|
| 197 |
+
|
| 198 |
+
query_embeds = None
|
| 199 |
+
if not self.two_stage:
|
| 200 |
+
query_embeds = self.query_embed.weight
|
| 201 |
+
(
|
| 202 |
+
hs,
|
| 203 |
+
init_reference,
|
| 204 |
+
inter_references,
|
| 205 |
+
enc_outputs_class,
|
| 206 |
+
enc_outputs_coord_unact,
|
| 207 |
+
anchors,
|
| 208 |
+
) = self.transformer(srcs, masks, pos, query_embeds)
|
| 209 |
+
|
| 210 |
+
outputs_classes = []
|
| 211 |
+
outputs_coords = []
|
| 212 |
+
for lvl in range(hs.shape[0]):
|
| 213 |
+
if lvl == 0:
|
| 214 |
+
reference = init_reference
|
| 215 |
+
else:
|
| 216 |
+
reference = inter_references[lvl - 1]
|
| 217 |
+
reference = inverse_sigmoid(reference)
|
| 218 |
+
outputs_class = self.class_embed[lvl](hs[lvl])
|
| 219 |
+
tmp = self.bbox_embed[lvl](hs[lvl])
|
| 220 |
+
if reference.shape[-1] == 4:
|
| 221 |
+
tmp += reference
|
| 222 |
+
else:
|
| 223 |
+
assert reference.shape[-1] == 2
|
| 224 |
+
tmp[..., :2] += reference
|
| 225 |
+
outputs_coord = tmp.sigmoid()
|
| 226 |
+
outputs_classes.append(outputs_class)
|
| 227 |
+
outputs_coords.append(outputs_coord)
|
| 228 |
+
outputs_class = torch.stack(outputs_classes)
|
| 229 |
+
outputs_coord = torch.stack(outputs_coords)
|
| 230 |
+
|
| 231 |
+
out = {
|
| 232 |
+
"pred_logits": outputs_class[-1],
|
| 233 |
+
"pred_boxes": outputs_coord[-1],
|
| 234 |
+
"init_reference": init_reference,
|
| 235 |
+
}
|
| 236 |
+
if self.aux_loss:
|
| 237 |
+
out["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord)
|
| 238 |
+
|
| 239 |
+
if self.two_stage:
|
| 240 |
+
enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
|
| 241 |
+
out["enc_outputs"] = {
|
| 242 |
+
"pred_logits": enc_outputs_class,
|
| 243 |
+
"pred_boxes": enc_outputs_coord,
|
| 244 |
+
"anchors": anchors,
|
| 245 |
+
}
|
| 246 |
+
return out
|
| 247 |
+
|
| 248 |
+
@torch.jit.unused
|
| 249 |
+
def _set_aux_loss(self, outputs_class, outputs_coord):
|
| 250 |
+
# this is a workaround to make torchscript happy, as torchscript
|
| 251 |
+
# doesn't support dictionary with non-homogeneous values, such
|
| 252 |
+
# as a dict having both a Tensor and a list.
|
| 253 |
+
return [
|
| 254 |
+
{"pred_logits": a, "pred_boxes": b}
|
| 255 |
+
for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
|
| 256 |
+
]
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class SetCriterion(nn.Module):
|
| 260 |
+
"""This class computes the loss for DETR.
|
| 261 |
+
The process happens in two steps:
|
| 262 |
+
1) we compute hungarian assignment between ground truth boxes and the outputs of the model
|
| 263 |
+
2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
|
| 264 |
+
"""
|
| 265 |
+
|
| 266 |
+
def __init__(
|
| 267 |
+
self,
|
| 268 |
+
num_classes,
|
| 269 |
+
matcher,
|
| 270 |
+
weight_dict,
|
| 271 |
+
losses,
|
| 272 |
+
focal_alpha=0.25,
|
| 273 |
+
num_queries=300,
|
| 274 |
+
assign_first_stage=False,
|
| 275 |
+
assign_second_stage=False,
|
| 276 |
+
use_fed_loss=False,
|
| 277 |
+
):
|
| 278 |
+
"""Create the criterion.
|
| 279 |
+
Parameters:
|
| 280 |
+
num_classes: number of object categories, omitting the special no-object category
|
| 281 |
+
matcher: module able to compute a matching between targets and proposals
|
| 282 |
+
weight_dict: dict containing as key the names of the losses and as values their relative weight.
|
| 283 |
+
losses: list of all the losses to be applied. See get_loss for list of available losses.
|
| 284 |
+
focal_alpha: alpha in Focal Loss
|
| 285 |
+
"""
|
| 286 |
+
super().__init__()
|
| 287 |
+
self.num_classes = num_classes
|
| 288 |
+
self.matcher = matcher
|
| 289 |
+
self.weight_dict = weight_dict
|
| 290 |
+
self.losses = losses
|
| 291 |
+
self.focal_alpha = focal_alpha
|
| 292 |
+
self.assign_first_stage = assign_first_stage
|
| 293 |
+
self.assign_second_stage = assign_second_stage
|
| 294 |
+
|
| 295 |
+
if self.assign_first_stage:
|
| 296 |
+
self.stg1_assigner = Stage1Assigner()
|
| 297 |
+
if self.assign_second_stage:
|
| 298 |
+
self.stg2_assigner = Stage2Assigner(num_queries)
|
| 299 |
+
|
| 300 |
+
self.use_fed_loss = use_fed_loss
|
| 301 |
+
if self.use_fed_loss:
|
| 302 |
+
print("Using federated loss")
|
| 303 |
+
print("Using federated loss")
|
| 304 |
+
print("Using federated loss")
|
| 305 |
+
self.register_buffer("fed_loss_weight", load_class_freq(freq_weight=0.5))
|
| 306 |
+
|
| 307 |
+
def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
|
| 308 |
+
"""Classification loss (NLL)
|
| 309 |
+
targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
|
| 310 |
+
"""
|
| 311 |
+
assert "pred_logits" in outputs
|
| 312 |
+
src_logits = outputs["pred_logits"]
|
| 313 |
+
|
| 314 |
+
idx = self._get_src_permutation_idx(indices)
|
| 315 |
+
target_classes_o = torch.cat(
|
| 316 |
+
[t["labels"][J] for t, (_, J) in zip(targets, indices)]
|
| 317 |
+
)
|
| 318 |
+
target_classes = torch.full(
|
| 319 |
+
src_logits.shape[:2],
|
| 320 |
+
self.num_classes,
|
| 321 |
+
dtype=torch.int64,
|
| 322 |
+
device=src_logits.device,
|
| 323 |
+
)
|
| 324 |
+
target_classes[idx] = target_classes_o
|
| 325 |
+
|
| 326 |
+
target_classes_onehot = torch.zeros(
|
| 327 |
+
[src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
|
| 328 |
+
dtype=src_logits.dtype,
|
| 329 |
+
layout=src_logits.layout,
|
| 330 |
+
device=src_logits.device,
|
| 331 |
+
)
|
| 332 |
+
target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
|
| 333 |
+
|
| 334 |
+
target_classes_onehot = target_classes_onehot[:, :, :-1]
|
| 335 |
+
if self.use_fed_loss:
|
| 336 |
+
inds = (
|
| 337 |
+
get_fed_loss_inds(
|
| 338 |
+
gt_classes=target_classes_o - 1,
|
| 339 |
+
num_sample_cats=50,
|
| 340 |
+
weight=self.fed_loss_weight,
|
| 341 |
+
C=target_classes_onehot.shape[2] - 1,
|
| 342 |
+
)
|
| 343 |
+
+ 1
|
| 344 |
+
) # pay attention to the -1 and +1
|
| 345 |
+
loss_ce = (
|
| 346 |
+
sigmoid_focal_loss(
|
| 347 |
+
src_logits[:, :, inds],
|
| 348 |
+
target_classes_onehot[:, :, inds],
|
| 349 |
+
num_boxes,
|
| 350 |
+
alpha=self.focal_alpha,
|
| 351 |
+
gamma=2,
|
| 352 |
+
)
|
| 353 |
+
* src_logits.shape[1]
|
| 354 |
+
)
|
| 355 |
+
else:
|
| 356 |
+
loss_ce = (
|
| 357 |
+
sigmoid_focal_loss(
|
| 358 |
+
src_logits,
|
| 359 |
+
target_classes_onehot,
|
| 360 |
+
num_boxes,
|
| 361 |
+
alpha=self.focal_alpha,
|
| 362 |
+
gamma=2,
|
| 363 |
+
)
|
| 364 |
+
* src_logits.shape[1]
|
| 365 |
+
)
|
| 366 |
+
losses = {"loss_ce": loss_ce}
|
| 367 |
+
|
| 368 |
+
if log:
|
| 369 |
+
# TODO this should probably be a separate loss, not hacked in this one here
|
| 370 |
+
losses["class_error"] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
|
| 371 |
+
return losses
|
| 372 |
+
|
| 373 |
+
@torch.no_grad()
|
| 374 |
+
def loss_cardinality(self, outputs, targets, indices, num_boxes):
|
| 375 |
+
"""Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
|
| 376 |
+
This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
|
| 377 |
+
"""
|
| 378 |
+
pred_logits = outputs["pred_logits"]
|
| 379 |
+
device = pred_logits.device
|
| 380 |
+
tgt_lengths = torch.as_tensor(
|
| 381 |
+
[len(v["labels"]) for v in targets], device=device
|
| 382 |
+
)
|
| 383 |
+
# Count the number of predictions that are NOT "no-object" (which is the last class)
|
| 384 |
+
card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
|
| 385 |
+
card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
|
| 386 |
+
losses = {"cardinality_error": card_err}
|
| 387 |
+
return losses
|
| 388 |
+
|
| 389 |
+
def loss_boxes(self, outputs, targets, indices, num_boxes):
|
| 390 |
+
"""Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
|
| 391 |
+
targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
|
| 392 |
+
The target boxes are expected in format (center_x, center_y, h, w), normalized by the image size.
|
| 393 |
+
"""
|
| 394 |
+
assert "pred_boxes" in outputs
|
| 395 |
+
idx = self._get_src_permutation_idx(indices)
|
| 396 |
+
src_boxes = outputs["pred_boxes"][idx]
|
| 397 |
+
target_boxes = torch.cat(
|
| 398 |
+
[t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")
|
| 402 |
+
|
| 403 |
+
losses = {}
|
| 404 |
+
losses["loss_bbox"] = loss_bbox.sum() / num_boxes
|
| 405 |
+
|
| 406 |
+
loss_giou = 1 - torch.diag(
|
| 407 |
+
box_ops.generalized_box_iou(
|
| 408 |
+
box_ops.box_cxcywh_to_xyxy(src_boxes),
|
| 409 |
+
box_ops.box_cxcywh_to_xyxy(target_boxes),
|
| 410 |
+
)
|
| 411 |
+
)
|
| 412 |
+
losses["loss_giou"] = loss_giou.sum() / num_boxes
|
| 413 |
+
return losses
|
| 414 |
+
|
| 415 |
+
def loss_masks(self, outputs, targets, indices, num_boxes):
|
| 416 |
+
"""Compute the losses related to the masks: the focal loss and the dice loss.
|
| 417 |
+
targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
|
| 418 |
+
"""
|
| 419 |
+
assert "pred_masks" in outputs
|
| 420 |
+
|
| 421 |
+
src_idx = self._get_src_permutation_idx(indices)
|
| 422 |
+
tgt_idx = self._get_tgt_permutation_idx(indices)
|
| 423 |
+
|
| 424 |
+
src_masks = outputs["pred_masks"]
|
| 425 |
+
|
| 426 |
+
# TODO use valid to mask invalid areas due to padding in loss
|
| 427 |
+
target_masks, valid = nested_tensor_from_tensor_list(
|
| 428 |
+
[t["masks"] for t in targets]
|
| 429 |
+
).decompose()
|
| 430 |
+
target_masks = target_masks.to(src_masks)
|
| 431 |
+
|
| 432 |
+
src_masks = src_masks[src_idx]
|
| 433 |
+
# upsample predictions to the target size
|
| 434 |
+
src_masks = interpolate(
|
| 435 |
+
src_masks[:, None],
|
| 436 |
+
size=target_masks.shape[-2:],
|
| 437 |
+
mode="bilinear",
|
| 438 |
+
align_corners=False,
|
| 439 |
+
)
|
| 440 |
+
src_masks = src_masks[:, 0].flatten(1)
|
| 441 |
+
|
| 442 |
+
target_masks = target_masks[tgt_idx].flatten(1)
|
| 443 |
+
|
| 444 |
+
losses = {
|
| 445 |
+
"loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
|
| 446 |
+
"loss_dice": dice_loss(src_masks, target_masks, num_boxes),
|
| 447 |
+
}
|
| 448 |
+
return losses
|
| 449 |
+
|
| 450 |
+
def _get_src_permutation_idx(self, indices):
|
| 451 |
+
# permute predictions following indices
|
| 452 |
+
batch_idx = torch.cat(
|
| 453 |
+
[torch.full_like(src, i) for i, (src, _) in enumerate(indices)]
|
| 454 |
+
)
|
| 455 |
+
src_idx = torch.cat([src for (src, _) in indices])
|
| 456 |
+
return batch_idx, src_idx
|
| 457 |
+
|
| 458 |
+
def _get_tgt_permutation_idx(self, indices):
|
| 459 |
+
# permute targets following indices
|
| 460 |
+
batch_idx = torch.cat(
|
| 461 |
+
[torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]
|
| 462 |
+
)
|
| 463 |
+
tgt_idx = torch.cat([tgt for (_, tgt) in indices])
|
| 464 |
+
return batch_idx, tgt_idx
|
| 465 |
+
|
| 466 |
+
def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
|
| 467 |
+
loss_map = {
|
| 468 |
+
"labels": self.loss_labels,
|
| 469 |
+
"cardinality": self.loss_cardinality,
|
| 470 |
+
"boxes": self.loss_boxes,
|
| 471 |
+
"masks": self.loss_masks,
|
| 472 |
+
}
|
| 473 |
+
assert loss in loss_map, f"do you really want to compute {loss} loss?"
|
| 474 |
+
return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
|
| 475 |
+
|
| 476 |
+
def forward(self, outputs, targets):
|
| 477 |
+
"""This performs the loss computation.
|
| 478 |
+
Parameters:
|
| 479 |
+
outputs: dict of tensors, see the output specification of the model for the format
|
| 480 |
+
targets: list of dicts, such that len(targets) == batch_size.
|
| 481 |
+
The expected keys in each dict depends on the losses applied, see each loss' doc
|
| 482 |
+
"""
|
| 483 |
+
outputs_without_aux = {
|
| 484 |
+
k: v
|
| 485 |
+
for k, v in outputs.items()
|
| 486 |
+
if k != "aux_outputs" and k != "enc_outputs"
|
| 487 |
+
}
|
| 488 |
+
|
| 489 |
+
# Retrieve the matching between the outputs of the last layer and the targets
|
| 490 |
+
if self.assign_second_stage:
|
| 491 |
+
indices = self.stg2_assigner(outputs_without_aux, targets)
|
| 492 |
+
else:
|
| 493 |
+
indices = self.matcher(outputs_without_aux, targets)
|
| 494 |
+
|
| 495 |
+
# Compute the average number of target boxes accross all nodes, for normalization purposes
|
| 496 |
+
num_boxes = sum(len(t["labels"]) for t in targets)
|
| 497 |
+
num_boxes = torch.as_tensor(
|
| 498 |
+
[num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device
|
| 499 |
+
)
|
| 500 |
+
if is_dist_avail_and_initialized():
|
| 501 |
+
torch.distributed.all_reduce(num_boxes)
|
| 502 |
+
num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
|
| 503 |
+
|
| 504 |
+
# Compute all the requested losses
|
| 505 |
+
losses = {}
|
| 506 |
+
for loss in self.losses:
|
| 507 |
+
kwargs = {}
|
| 508 |
+
losses.update(
|
| 509 |
+
self.get_loss(loss, outputs, targets, indices, num_boxes, **kwargs)
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
|
| 513 |
+
if "aux_outputs" in outputs:
|
| 514 |
+
for i, aux_outputs in enumerate(outputs["aux_outputs"]):
|
| 515 |
+
if not self.assign_second_stage:
|
| 516 |
+
indices = self.matcher(aux_outputs, targets)
|
| 517 |
+
for loss in self.losses:
|
| 518 |
+
if loss == "masks":
|
| 519 |
+
# Intermediate masks losses are too costly to compute, we ignore them.
|
| 520 |
+
continue
|
| 521 |
+
kwargs = {}
|
| 522 |
+
if loss == "labels":
|
| 523 |
+
# Logging is enabled only for the last layer
|
| 524 |
+
kwargs["log"] = False
|
| 525 |
+
l_dict = self.get_loss(
|
| 526 |
+
loss, aux_outputs, targets, indices, num_boxes, **kwargs
|
| 527 |
+
)
|
| 528 |
+
l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
|
| 529 |
+
losses.update(l_dict)
|
| 530 |
+
|
| 531 |
+
if "enc_outputs" in outputs:
|
| 532 |
+
enc_outputs = outputs["enc_outputs"]
|
| 533 |
+
bin_targets = copy.deepcopy(targets)
|
| 534 |
+
for bt in bin_targets:
|
| 535 |
+
bt["labels"] = torch.zeros_like(bt["labels"])
|
| 536 |
+
if self.assign_first_stage:
|
| 537 |
+
indices = self.stg1_assigner(enc_outputs, bin_targets)
|
| 538 |
+
else:
|
| 539 |
+
indices = self.matcher(enc_outputs, bin_targets)
|
| 540 |
+
for loss in self.losses:
|
| 541 |
+
if loss == "masks":
|
| 542 |
+
# Intermediate masks losses are too costly to compute, we ignore them.
|
| 543 |
+
continue
|
| 544 |
+
kwargs = {}
|
| 545 |
+
if loss == "labels":
|
| 546 |
+
# Logging is enabled only for the last layer
|
| 547 |
+
kwargs["log"] = False
|
| 548 |
+
l_dict = self.get_loss(
|
| 549 |
+
loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs
|
| 550 |
+
)
|
| 551 |
+
l_dict = {k + f"_enc": v for k, v in l_dict.items()}
|
| 552 |
+
losses.update(l_dict)
|
| 553 |
+
|
| 554 |
+
return losses
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
class PostProcess(nn.Module):
|
| 558 |
+
"""This module converts the model's output into the format expected by the coco api"""
|
| 559 |
+
|
| 560 |
+
@torch.no_grad()
|
| 561 |
+
def forward(self, outputs, target_sizes, num_topk=100):
|
| 562 |
+
"""Perform the computation
|
| 563 |
+
Parameters:
|
| 564 |
+
outputs: raw outputs of the model
|
| 565 |
+
target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
|
| 566 |
+
For evaluation, this must be the original image size (before any data augmentation)
|
| 567 |
+
For visualization, this should be the image size after data augment, but before padding
|
| 568 |
+
"""
|
| 569 |
+
out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"]
|
| 570 |
+
|
| 571 |
+
assert len(out_logits) == len(target_sizes)
|
| 572 |
+
assert target_sizes.shape[1] == 2
|
| 573 |
+
|
| 574 |
+
prob = out_logits.sigmoid()
|
| 575 |
+
topk_values, topk_indexes = torch.topk(
|
| 576 |
+
prob.view(out_logits.shape[0], -1), num_topk, dim=1
|
| 577 |
+
)
|
| 578 |
+
scores = topk_values
|
| 579 |
+
topk_boxes = topk_indexes // out_logits.shape[2]
|
| 580 |
+
labels = topk_indexes % out_logits.shape[2]
|
| 581 |
+
boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
|
| 582 |
+
boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
|
| 583 |
+
|
| 584 |
+
# and from relative [0, 1] to absolute [0, height] coordinates
|
| 585 |
+
img_h, img_w = target_sizes.unbind(1)
|
| 586 |
+
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
|
| 587 |
+
boxes = boxes * scale_fct[:, None, :]
|
| 588 |
+
|
| 589 |
+
results = [
|
| 590 |
+
{"scores": s, "labels": l, "boxes": b}
|
| 591 |
+
for s, l, b in zip(scores, labels, boxes)
|
| 592 |
+
]
|
| 593 |
+
|
| 594 |
+
return results
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
class NMSPostProcess(nn.Module):
|
| 598 |
+
"""This module converts the model's output into the format expected by the coco api"""
|
| 599 |
+
|
| 600 |
+
@torch.no_grad()
|
| 601 |
+
def forward(
|
| 602 |
+
self,
|
| 603 |
+
outputs,
|
| 604 |
+
target_sizes,
|
| 605 |
+
num_topk=100,
|
| 606 |
+
soft_nms=False,
|
| 607 |
+
nms_thresh=0.7,
|
| 608 |
+
method="quad",
|
| 609 |
+
quad_scale=1.0,
|
| 610 |
+
):
|
| 611 |
+
"""Perform the computation
|
| 612 |
+
Parameters:
|
| 613 |
+
outputs: raw outputs of the model
|
| 614 |
+
target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
|
| 615 |
+
For evaluation, this must be the original image size (before any data augmentation)
|
| 616 |
+
For visualization, this should be the image size after data augment, but before padding
|
| 617 |
+
"""
|
| 618 |
+
out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"]
|
| 619 |
+
bs, n_queries, n_cls = out_logits.shape
|
| 620 |
+
|
| 621 |
+
assert len(out_logits) == len(target_sizes)
|
| 622 |
+
assert target_sizes.shape[1] == 2
|
| 623 |
+
|
| 624 |
+
prob = out_logits.sigmoid()
|
| 625 |
+
|
| 626 |
+
all_scores = prob.view(bs, n_queries * n_cls).to(out_logits.device)
|
| 627 |
+
all_indexes = (
|
| 628 |
+
torch.arange(n_queries * n_cls)[None].repeat(bs, 1).to(out_logits.device)
|
| 629 |
+
)
|
| 630 |
+
all_boxes = all_indexes // out_logits.shape[2]
|
| 631 |
+
all_labels = all_indexes % out_logits.shape[2]
|
| 632 |
+
|
| 633 |
+
boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
|
| 634 |
+
boxes = torch.gather(boxes, 1, all_boxes.unsqueeze(-1).repeat(1, 1, 4))
|
| 635 |
+
|
| 636 |
+
# and from relative [0, 1] to absolute [0, height] coordinates
|
| 637 |
+
img_h, img_w = target_sizes.unbind(1)
|
| 638 |
+
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
|
| 639 |
+
boxes = boxes * scale_fct[:, None, :]
|
| 640 |
+
|
| 641 |
+
results = []
|
| 642 |
+
for b in range(bs):
|
| 643 |
+
box = boxes[b]
|
| 644 |
+
score = all_scores[b]
|
| 645 |
+
lbls = all_labels[b]
|
| 646 |
+
|
| 647 |
+
if soft_nms:
|
| 648 |
+
if n_queries * n_cls > 2000:
|
| 649 |
+
pre_topk = score.topk(2000).indices
|
| 650 |
+
box = box[pre_topk]
|
| 651 |
+
score = score[pre_topk]
|
| 652 |
+
lbls = lbls[pre_topk]
|
| 653 |
+
# Apply soft-NMS to get indices and updated scores
|
| 654 |
+
keep_inds, updated_scores = batched_soft_nms(
|
| 655 |
+
box,
|
| 656 |
+
score,
|
| 657 |
+
lbls,
|
| 658 |
+
nms_thresh,
|
| 659 |
+
method=method,
|
| 660 |
+
quad_scale=quad_scale,
|
| 661 |
+
)[:num_topk]
|
| 662 |
+
|
| 663 |
+
results.append(
|
| 664 |
+
{
|
| 665 |
+
"scores": updated_scores,
|
| 666 |
+
"labels": lbls[keep_inds],
|
| 667 |
+
"boxes": box[keep_inds],
|
| 668 |
+
}
|
| 669 |
+
)
|
| 670 |
+
else:
|
| 671 |
+
if n_queries * n_cls > 10000:
|
| 672 |
+
pre_topk = score.topk(10000).indices
|
| 673 |
+
box = box[pre_topk]
|
| 674 |
+
score = score[pre_topk]
|
| 675 |
+
lbls = lbls[pre_topk]
|
| 676 |
+
keep_inds = batched_nms(box, score, lbls, nms_thresh)[:num_topk]
|
| 677 |
+
results.append(
|
| 678 |
+
{
|
| 679 |
+
"scores": score[keep_inds],
|
| 680 |
+
"labels": lbls[keep_inds],
|
| 681 |
+
"boxes": box[keep_inds],
|
| 682 |
+
}
|
| 683 |
+
)
|
| 684 |
+
|
| 685 |
+
return results
|
| 686 |
+
|
| 687 |
+
|
| 688 |
+
class MLP(nn.Module):
|
| 689 |
+
"""Very simple multi-layer perceptron (also called FFN)"""
|
| 690 |
+
|
| 691 |
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
| 692 |
+
super().__init__()
|
| 693 |
+
self.num_layers = num_layers
|
| 694 |
+
h = [hidden_dim] * (num_layers - 1)
|
| 695 |
+
self.layers = nn.ModuleList(
|
| 696 |
+
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
def forward(self, x):
|
| 700 |
+
for i, layer in enumerate(self.layers):
|
| 701 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
| 702 |
+
return x
|
| 703 |
+
|
| 704 |
+
|
| 705 |
+
def build(args):
|
| 706 |
+
# num_classes = 20 if args.dataset_file != 'coco' else 91
|
| 707 |
+
if args.dataset_file == "coco_panoptic":
|
| 708 |
+
num_classes = 250
|
| 709 |
+
elif args.dataset_file == "voc":
|
| 710 |
+
num_classes = 20
|
| 711 |
+
elif args.dataset_file == "objects365":
|
| 712 |
+
num_classes = 366
|
| 713 |
+
elif args.dataset_file == "lvis":
|
| 714 |
+
num_classes = 1204
|
| 715 |
+
else: # coco
|
| 716 |
+
num_classes = 91
|
| 717 |
+
device = torch.device(args.device)
|
| 718 |
+
|
| 719 |
+
backbone = build_backbone(args)
|
| 720 |
+
|
| 721 |
+
transformer = build_deforamble_transformer(args)
|
| 722 |
+
model = DeformableDETR(
|
| 723 |
+
backbone,
|
| 724 |
+
transformer,
|
| 725 |
+
num_classes=num_classes,
|
| 726 |
+
num_queries=args.num_queries,
|
| 727 |
+
num_feature_levels=args.num_feature_levels,
|
| 728 |
+
aux_loss=args.aux_loss,
|
| 729 |
+
with_box_refine=args.with_box_refine,
|
| 730 |
+
two_stage=args.two_stage,
|
| 731 |
+
)
|
| 732 |
+
if args.masks:
|
| 733 |
+
model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None))
|
| 734 |
+
matcher = build_matcher(args)
|
| 735 |
+
weight_dict = {"loss_ce": args.cls_loss_coef, "loss_bbox": args.bbox_loss_coef}
|
| 736 |
+
weight_dict["loss_giou"] = args.giou_loss_coef
|
| 737 |
+
if args.masks:
|
| 738 |
+
weight_dict["loss_mask"] = args.mask_loss_coef
|
| 739 |
+
weight_dict["loss_dice"] = args.dice_loss_coef
|
| 740 |
+
# TODO this is a hack
|
| 741 |
+
if args.aux_loss:
|
| 742 |
+
aux_weight_dict = {}
|
| 743 |
+
for i in range(args.dec_layers - 1):
|
| 744 |
+
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
|
| 745 |
+
aux_weight_dict.update({k + f"_enc": v for k, v in weight_dict.items()})
|
| 746 |
+
weight_dict.update(aux_weight_dict)
|
| 747 |
+
|
| 748 |
+
losses = ["labels", "boxes", "cardinality"]
|
| 749 |
+
if args.masks:
|
| 750 |
+
losses += ["masks"]
|
| 751 |
+
# num_classes, matcher, weight_dict, losses, focal_alpha=0.25
|
| 752 |
+
criterion = SetCriterion(
|
| 753 |
+
num_classes,
|
| 754 |
+
matcher,
|
| 755 |
+
weight_dict,
|
| 756 |
+
losses,
|
| 757 |
+
focal_alpha=args.focal_alpha,
|
| 758 |
+
num_queries=args.num_queries,
|
| 759 |
+
assign_first_stage=args.assign_first_stage,
|
| 760 |
+
assign_second_stage=args.assign_second_stage,
|
| 761 |
+
use_fed_loss=args.use_fed_loss,
|
| 762 |
+
)
|
| 763 |
+
criterion.to(device)
|
| 764 |
+
if args.assign_second_stage:
|
| 765 |
+
postprocessors = {"bbox": NMSPostProcess()}
|
| 766 |
+
else:
|
| 767 |
+
postprocessors = {"bbox": PostProcess()}
|
| 768 |
+
if args.masks:
|
| 769 |
+
postprocessors["segm"] = PostProcessSegm()
|
| 770 |
+
if args.dataset_file == "coco_panoptic":
|
| 771 |
+
is_thing_map = {i: i <= 90 for i in range(201)}
|
| 772 |
+
postprocessors["panoptic"] = PostProcessPanoptic(
|
| 773 |
+
is_thing_map, threshold=0.85
|
| 774 |
+
)
|
| 775 |
+
|
| 776 |
+
return model, criterion, postprocessors
|
perception_models/apps/detection/DETA_pe/models/deformable_transformer.py
ADDED
|
@@ -0,0 +1,451 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
# Modified from DETR (https://github.com/facebookresearch/detr)
|
| 7 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 8 |
+
# ------------------------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
import copy
|
| 11 |
+
from typing import Optional, List
|
| 12 |
+
import math
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
from torch import nn, Tensor
|
| 17 |
+
from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
|
| 18 |
+
|
| 19 |
+
from util.misc import inverse_sigmoid
|
| 20 |
+
from models.ops.modules import MSDeformAttn
|
| 21 |
+
|
| 22 |
+
from torchvision.ops.boxes import batched_nms
|
| 23 |
+
from util.box_ops import box_cxcywh_to_xyxy
|
| 24 |
+
|
| 25 |
+
class DeformableTransformer(nn.Module):
|
| 26 |
+
def __init__(self, d_model=256, nhead=8,
|
| 27 |
+
num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,
|
| 28 |
+
activation="relu", return_intermediate_dec=False,
|
| 29 |
+
num_feature_levels=4, dec_n_points=4, enc_n_points=4,
|
| 30 |
+
two_stage=False, two_stage_num_proposals=300,
|
| 31 |
+
assign_first_stage=False):
|
| 32 |
+
super().__init__()
|
| 33 |
+
|
| 34 |
+
self.d_model = d_model
|
| 35 |
+
self.nhead = nhead
|
| 36 |
+
self.two_stage = two_stage
|
| 37 |
+
self.two_stage_num_proposals = two_stage_num_proposals
|
| 38 |
+
self.assign_first_stage = assign_first_stage
|
| 39 |
+
|
| 40 |
+
encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward,
|
| 41 |
+
dropout, activation,
|
| 42 |
+
num_feature_levels, nhead, enc_n_points)
|
| 43 |
+
self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)
|
| 44 |
+
|
| 45 |
+
decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward,
|
| 46 |
+
dropout, activation,
|
| 47 |
+
num_feature_levels, nhead, dec_n_points)
|
| 48 |
+
self.decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers, return_intermediate_dec)
|
| 49 |
+
|
| 50 |
+
self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
|
| 51 |
+
|
| 52 |
+
if two_stage:
|
| 53 |
+
self.enc_output = nn.Linear(d_model, d_model)
|
| 54 |
+
self.enc_output_norm = nn.LayerNorm(d_model)
|
| 55 |
+
self.pos_trans = nn.Linear(d_model * 2, d_model * 2)
|
| 56 |
+
self.pos_trans_norm = nn.LayerNorm(d_model * 2)
|
| 57 |
+
self.pix_trans = nn.Linear(d_model, d_model)
|
| 58 |
+
self.pix_trans_norm = nn.LayerNorm(d_model)
|
| 59 |
+
else:
|
| 60 |
+
self.reference_points = nn.Linear(d_model, 2)
|
| 61 |
+
|
| 62 |
+
self._reset_parameters()
|
| 63 |
+
|
| 64 |
+
def _reset_parameters(self):
|
| 65 |
+
for p in self.parameters():
|
| 66 |
+
if p.dim() > 1:
|
| 67 |
+
nn.init.xavier_uniform_(p)
|
| 68 |
+
for m in self.modules():
|
| 69 |
+
if isinstance(m, MSDeformAttn):
|
| 70 |
+
m._reset_parameters()
|
| 71 |
+
if not self.two_stage:
|
| 72 |
+
xavier_uniform_(self.reference_points.weight.data, gain=1.0)
|
| 73 |
+
constant_(self.reference_points.bias.data, 0.)
|
| 74 |
+
normal_(self.level_embed)
|
| 75 |
+
|
| 76 |
+
def get_proposal_pos_embed(self, proposals):
|
| 77 |
+
num_pos_feats = 128
|
| 78 |
+
temperature = 10000
|
| 79 |
+
scale = 2 * math.pi
|
| 80 |
+
|
| 81 |
+
dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
|
| 82 |
+
dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
|
| 83 |
+
# N, L, 4
|
| 84 |
+
proposals = proposals.sigmoid() * scale
|
| 85 |
+
# N, L, 4, 128
|
| 86 |
+
pos = proposals[:, :, :, None] / dim_t
|
| 87 |
+
# N, L, 4, 64, 2
|
| 88 |
+
pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
|
| 89 |
+
return pos
|
| 90 |
+
|
| 91 |
+
def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
|
| 92 |
+
N_, S_, C_ = memory.shape
|
| 93 |
+
base_scale = 4.0
|
| 94 |
+
proposals = []
|
| 95 |
+
_cur = 0
|
| 96 |
+
level_ids = []
|
| 97 |
+
for lvl, (H_, W_) in enumerate(spatial_shapes):
|
| 98 |
+
mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1)
|
| 99 |
+
valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
|
| 100 |
+
valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
|
| 101 |
+
|
| 102 |
+
grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
|
| 103 |
+
torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device))
|
| 104 |
+
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
|
| 105 |
+
|
| 106 |
+
scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
|
| 107 |
+
grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
|
| 108 |
+
wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
|
| 109 |
+
proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
|
| 110 |
+
proposals.append(proposal)
|
| 111 |
+
_cur += (H_ * W_)
|
| 112 |
+
level_ids.append(grid.new_ones(H_ * W_, dtype=torch.long) * lvl)
|
| 113 |
+
output_proposals = torch.cat(proposals, 1)
|
| 114 |
+
output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
|
| 115 |
+
output_proposals = torch.log(output_proposals / (1 - output_proposals))
|
| 116 |
+
output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
|
| 117 |
+
output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))
|
| 118 |
+
|
| 119 |
+
output_memory = memory
|
| 120 |
+
output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
|
| 121 |
+
output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
|
| 122 |
+
output_memory = self.enc_output_norm(self.enc_output(output_memory))
|
| 123 |
+
level_ids = torch.cat(level_ids)
|
| 124 |
+
return output_memory, output_proposals, level_ids
|
| 125 |
+
|
| 126 |
+
def get_valid_ratio(self, mask):
|
| 127 |
+
_, H, W = mask.shape
|
| 128 |
+
valid_H = torch.sum(~mask[:, :, 0], 1)
|
| 129 |
+
valid_W = torch.sum(~mask[:, 0, :], 1)
|
| 130 |
+
valid_ratio_h = valid_H.float() / H
|
| 131 |
+
valid_ratio_w = valid_W.float() / W
|
| 132 |
+
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
|
| 133 |
+
return valid_ratio
|
| 134 |
+
|
| 135 |
+
def forward(self, srcs, masks, pos_embeds, query_embed=None):
|
| 136 |
+
assert self.two_stage or query_embed is not None
|
| 137 |
+
|
| 138 |
+
# prepare input for encoder
|
| 139 |
+
src_flatten = []
|
| 140 |
+
mask_flatten = []
|
| 141 |
+
lvl_pos_embed_flatten = []
|
| 142 |
+
spatial_shapes = []
|
| 143 |
+
for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
|
| 144 |
+
bs, c, h, w = src.shape
|
| 145 |
+
spatial_shape = (h, w)
|
| 146 |
+
spatial_shapes.append(spatial_shape)
|
| 147 |
+
src = src.flatten(2).transpose(1, 2)
|
| 148 |
+
mask = mask.flatten(1)
|
| 149 |
+
pos_embed = pos_embed.flatten(2).transpose(1, 2)
|
| 150 |
+
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
|
| 151 |
+
lvl_pos_embed_flatten.append(lvl_pos_embed)
|
| 152 |
+
src_flatten.append(src)
|
| 153 |
+
mask_flatten.append(mask)
|
| 154 |
+
src_flatten = torch.cat(src_flatten, 1)
|
| 155 |
+
mask_flatten = torch.cat(mask_flatten, 1)
|
| 156 |
+
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
|
| 157 |
+
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
|
| 158 |
+
level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
|
| 159 |
+
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
|
| 160 |
+
|
| 161 |
+
# encoder
|
| 162 |
+
memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten)
|
| 163 |
+
|
| 164 |
+
# prepare input for decoder
|
| 165 |
+
bs, _, c = memory.shape
|
| 166 |
+
if self.two_stage:
|
| 167 |
+
output_memory, output_proposals, level_ids = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)
|
| 168 |
+
|
| 169 |
+
# hack implementation for two-stage Deformable DETR
|
| 170 |
+
enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory)
|
| 171 |
+
enc_outputs_coord_unact = self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals
|
| 172 |
+
|
| 173 |
+
topk = self.two_stage_num_proposals
|
| 174 |
+
proposal_logit = enc_outputs_class[..., 0]
|
| 175 |
+
|
| 176 |
+
if self.assign_first_stage:
|
| 177 |
+
proposal_boxes = box_cxcywh_to_xyxy(enc_outputs_coord_unact.sigmoid().float()).clamp(0, 1)
|
| 178 |
+
topk_proposals = []
|
| 179 |
+
for b in range(bs):
|
| 180 |
+
prop_boxes_b = proposal_boxes[b]
|
| 181 |
+
prop_logits_b = proposal_logit[b]
|
| 182 |
+
|
| 183 |
+
# pre-nms per-level topk
|
| 184 |
+
pre_nms_topk = 1000
|
| 185 |
+
pre_nms_inds = []
|
| 186 |
+
for lvl in range(len(spatial_shapes)):
|
| 187 |
+
lvl_mask = level_ids == lvl
|
| 188 |
+
pre_nms_inds.append(torch.topk(prop_logits_b.sigmoid() * lvl_mask, pre_nms_topk)[1])
|
| 189 |
+
pre_nms_inds = torch.cat(pre_nms_inds)
|
| 190 |
+
|
| 191 |
+
# nms on topk indices
|
| 192 |
+
post_nms_inds = batched_nms(prop_boxes_b[pre_nms_inds], prop_logits_b[pre_nms_inds], level_ids[pre_nms_inds], 0.9)
|
| 193 |
+
keep_inds = pre_nms_inds[post_nms_inds]
|
| 194 |
+
|
| 195 |
+
if len(keep_inds) < self.two_stage_num_proposals:
|
| 196 |
+
print(f'[WARNING] nms proposals ({len(keep_inds)}) < {self.two_stage_num_proposals}, running naive topk')
|
| 197 |
+
keep_inds = torch.topk(proposal_logit[b], topk)[1]
|
| 198 |
+
|
| 199 |
+
# keep top Q/L indices for L levels
|
| 200 |
+
q_per_l = topk // len(spatial_shapes)
|
| 201 |
+
is_level_ordered = level_ids[keep_inds][None] == torch.arange(len(spatial_shapes), device=level_ids.device)[:,None] # LS
|
| 202 |
+
keep_inds_mask = is_level_ordered & (is_level_ordered.cumsum(1) <= q_per_l) # LS
|
| 203 |
+
keep_inds_mask = keep_inds_mask.any(0) # S
|
| 204 |
+
|
| 205 |
+
# pad to Q indices (might let ones filtered from pre-nms sneak by... unlikely because we pick high conf anyways)
|
| 206 |
+
if keep_inds_mask.sum() < topk:
|
| 207 |
+
num_to_add = topk - keep_inds_mask.sum()
|
| 208 |
+
pad_inds = (~keep_inds_mask).nonzero()[:num_to_add]
|
| 209 |
+
keep_inds_mask[pad_inds] = True
|
| 210 |
+
|
| 211 |
+
# index
|
| 212 |
+
keep_inds_topk = keep_inds[keep_inds_mask]
|
| 213 |
+
topk_proposals.append(keep_inds_topk)
|
| 214 |
+
topk_proposals = torch.stack(topk_proposals)
|
| 215 |
+
else:
|
| 216 |
+
topk_proposals = torch.topk(proposal_logit, topk, dim=1)[1]
|
| 217 |
+
|
| 218 |
+
topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
|
| 219 |
+
topk_coords_unact = topk_coords_unact.detach()
|
| 220 |
+
reference_points = topk_coords_unact.sigmoid()
|
| 221 |
+
init_reference_out = reference_points
|
| 222 |
+
pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
|
| 223 |
+
query_embed, tgt = torch.split(pos_trans_out, c, dim=2)
|
| 224 |
+
|
| 225 |
+
topk_feats = torch.stack([output_memory[b][topk_proposals[b]] for b in range(bs)]).detach()
|
| 226 |
+
tgt = tgt + self.pix_trans_norm(self.pix_trans(topk_feats))
|
| 227 |
+
else:
|
| 228 |
+
query_embed, tgt = torch.split(query_embed, c, dim=1)
|
| 229 |
+
query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
|
| 230 |
+
tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
|
| 231 |
+
reference_points = self.reference_points(query_embed).sigmoid()
|
| 232 |
+
init_reference_out = reference_points
|
| 233 |
+
|
| 234 |
+
# decoder
|
| 235 |
+
hs, inter_references = self.decoder(tgt, reference_points, memory,
|
| 236 |
+
spatial_shapes, level_start_index, valid_ratios, query_embed, mask_flatten)
|
| 237 |
+
|
| 238 |
+
inter_references_out = inter_references
|
| 239 |
+
if self.two_stage:
|
| 240 |
+
return hs, init_reference_out, inter_references_out, enc_outputs_class, enc_outputs_coord_unact, output_proposals.sigmoid()
|
| 241 |
+
return hs, init_reference_out, inter_references_out, None, None, None
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class DeformableTransformerEncoderLayer(nn.Module):
|
| 245 |
+
def __init__(self,
|
| 246 |
+
d_model=256, d_ffn=1024,
|
| 247 |
+
dropout=0.1, activation="relu",
|
| 248 |
+
n_levels=4, n_heads=8, n_points=4):
|
| 249 |
+
super().__init__()
|
| 250 |
+
|
| 251 |
+
# self attention
|
| 252 |
+
self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
|
| 253 |
+
self.dropout1 = nn.Dropout(dropout)
|
| 254 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 255 |
+
|
| 256 |
+
# ffn
|
| 257 |
+
self.linear1 = nn.Linear(d_model, d_ffn)
|
| 258 |
+
self.activation = _get_activation_fn(activation)
|
| 259 |
+
self.dropout2 = nn.Dropout(dropout)
|
| 260 |
+
self.linear2 = nn.Linear(d_ffn, d_model)
|
| 261 |
+
self.dropout3 = nn.Dropout(dropout)
|
| 262 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 263 |
+
|
| 264 |
+
@staticmethod
|
| 265 |
+
def with_pos_embed(tensor, pos):
|
| 266 |
+
return tensor if pos is None else tensor + pos
|
| 267 |
+
|
| 268 |
+
def forward_ffn(self, src):
|
| 269 |
+
src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
|
| 270 |
+
src = src + self.dropout3(src2)
|
| 271 |
+
src = self.norm2(src)
|
| 272 |
+
return src
|
| 273 |
+
|
| 274 |
+
def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
|
| 275 |
+
# self attention
|
| 276 |
+
src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)
|
| 277 |
+
src = src + self.dropout1(src2)
|
| 278 |
+
src = self.norm1(src)
|
| 279 |
+
|
| 280 |
+
# ffn
|
| 281 |
+
src = self.forward_ffn(src)
|
| 282 |
+
|
| 283 |
+
return src
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
class DeformableTransformerEncoder(nn.Module):
|
| 287 |
+
def __init__(self, encoder_layer, num_layers):
|
| 288 |
+
super().__init__()
|
| 289 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
| 290 |
+
self.num_layers = num_layers
|
| 291 |
+
|
| 292 |
+
@staticmethod
|
| 293 |
+
def get_reference_points(spatial_shapes, valid_ratios, device):
|
| 294 |
+
reference_points_list = []
|
| 295 |
+
for lvl, (H_, W_) in enumerate(spatial_shapes):
|
| 296 |
+
|
| 297 |
+
ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
|
| 298 |
+
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
|
| 299 |
+
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
|
| 300 |
+
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
|
| 301 |
+
ref = torch.stack((ref_x, ref_y), -1)
|
| 302 |
+
reference_points_list.append(ref)
|
| 303 |
+
reference_points = torch.cat(reference_points_list, 1)
|
| 304 |
+
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
|
| 305 |
+
return reference_points
|
| 306 |
+
|
| 307 |
+
def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
|
| 308 |
+
output = src
|
| 309 |
+
reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
|
| 310 |
+
for _, layer in enumerate(self.layers):
|
| 311 |
+
output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)
|
| 312 |
+
|
| 313 |
+
return output
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
class DeformableTransformerDecoderLayer(nn.Module):
|
| 317 |
+
def __init__(self, d_model=256, d_ffn=1024,
|
| 318 |
+
dropout=0.1, activation="relu",
|
| 319 |
+
n_levels=4, n_heads=8, n_points=4):
|
| 320 |
+
super().__init__()
|
| 321 |
+
|
| 322 |
+
# cross attention
|
| 323 |
+
self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
|
| 324 |
+
self.dropout1 = nn.Dropout(dropout)
|
| 325 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 326 |
+
|
| 327 |
+
# self attention
|
| 328 |
+
self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
|
| 329 |
+
self.dropout2 = nn.Dropout(dropout)
|
| 330 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 331 |
+
|
| 332 |
+
# ffn
|
| 333 |
+
self.linear1 = nn.Linear(d_model, d_ffn)
|
| 334 |
+
self.activation = _get_activation_fn(activation)
|
| 335 |
+
self.dropout3 = nn.Dropout(dropout)
|
| 336 |
+
self.linear2 = nn.Linear(d_ffn, d_model)
|
| 337 |
+
self.dropout4 = nn.Dropout(dropout)
|
| 338 |
+
self.norm3 = nn.LayerNorm(d_model)
|
| 339 |
+
|
| 340 |
+
@staticmethod
|
| 341 |
+
def with_pos_embed(tensor, pos):
|
| 342 |
+
return tensor if pos is None else tensor + pos
|
| 343 |
+
|
| 344 |
+
def forward_ffn(self, tgt):
|
| 345 |
+
tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
|
| 346 |
+
tgt = tgt + self.dropout4(tgt2)
|
| 347 |
+
tgt = self.norm3(tgt)
|
| 348 |
+
return tgt
|
| 349 |
+
|
| 350 |
+
def forward(self, tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index, src_padding_mask=None):
|
| 351 |
+
# self attention
|
| 352 |
+
q = k = self.with_pos_embed(tgt, query_pos)
|
| 353 |
+
tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1)
|
| 354 |
+
tgt = tgt + self.dropout2(tgt2)
|
| 355 |
+
tgt = self.norm2(tgt)
|
| 356 |
+
|
| 357 |
+
# cross attention
|
| 358 |
+
tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos),
|
| 359 |
+
reference_points,
|
| 360 |
+
src, src_spatial_shapes, level_start_index, src_padding_mask)
|
| 361 |
+
tgt = tgt + self.dropout1(tgt2)
|
| 362 |
+
tgt = self.norm1(tgt)
|
| 363 |
+
|
| 364 |
+
# ffn
|
| 365 |
+
tgt = self.forward_ffn(tgt)
|
| 366 |
+
|
| 367 |
+
return tgt
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
class DeformableTransformerDecoder(nn.Module):
|
| 371 |
+
def __init__(self, decoder_layer, num_layers, return_intermediate=False):
|
| 372 |
+
super().__init__()
|
| 373 |
+
self.layers = _get_clones(decoder_layer, num_layers)
|
| 374 |
+
self.num_layers = num_layers
|
| 375 |
+
self.return_intermediate = return_intermediate
|
| 376 |
+
# hack implementation for iterative bounding box refinement and two-stage Deformable DETR
|
| 377 |
+
self.bbox_embed = None
|
| 378 |
+
self.class_embed = None
|
| 379 |
+
|
| 380 |
+
def forward(self, tgt, reference_points, src, src_spatial_shapes, src_level_start_index, src_valid_ratios,
|
| 381 |
+
query_pos=None, src_padding_mask=None):
|
| 382 |
+
output = tgt
|
| 383 |
+
|
| 384 |
+
intermediate = []
|
| 385 |
+
intermediate_reference_points = []
|
| 386 |
+
for lid, layer in enumerate(self.layers):
|
| 387 |
+
if reference_points.shape[-1] == 4:
|
| 388 |
+
reference_points_input = reference_points[:, :, None] \
|
| 389 |
+
* torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None]
|
| 390 |
+
else:
|
| 391 |
+
assert reference_points.shape[-1] == 2
|
| 392 |
+
reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None]
|
| 393 |
+
output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes, src_level_start_index, src_padding_mask)
|
| 394 |
+
|
| 395 |
+
# hack implementation for iterative bounding box refinement
|
| 396 |
+
if self.bbox_embed is not None:
|
| 397 |
+
tmp = self.bbox_embed[lid](output)
|
| 398 |
+
if reference_points.shape[-1] == 4:
|
| 399 |
+
new_reference_points = tmp + inverse_sigmoid(reference_points)
|
| 400 |
+
new_reference_points = new_reference_points.sigmoid()
|
| 401 |
+
else:
|
| 402 |
+
assert reference_points.shape[-1] == 2
|
| 403 |
+
new_reference_points = tmp
|
| 404 |
+
new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points)
|
| 405 |
+
new_reference_points = new_reference_points.sigmoid()
|
| 406 |
+
reference_points = new_reference_points.detach()
|
| 407 |
+
|
| 408 |
+
if self.return_intermediate:
|
| 409 |
+
intermediate.append(output)
|
| 410 |
+
intermediate_reference_points.append(reference_points)
|
| 411 |
+
|
| 412 |
+
if self.return_intermediate:
|
| 413 |
+
return torch.stack(intermediate), torch.stack(intermediate_reference_points)
|
| 414 |
+
|
| 415 |
+
return output, reference_points
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def _get_clones(module, N):
|
| 419 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
def _get_activation_fn(activation):
|
| 423 |
+
"""Return an activation function given a string"""
|
| 424 |
+
if activation == "relu":
|
| 425 |
+
return F.relu
|
| 426 |
+
if activation == "gelu":
|
| 427 |
+
return F.gelu
|
| 428 |
+
if activation == "glu":
|
| 429 |
+
return F.glu
|
| 430 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def build_deforamble_transformer(args):
|
| 434 |
+
return DeformableTransformer(
|
| 435 |
+
d_model=args.hidden_dim,
|
| 436 |
+
nhead=args.nheads,
|
| 437 |
+
num_encoder_layers=args.enc_layers,
|
| 438 |
+
num_decoder_layers=args.dec_layers,
|
| 439 |
+
dim_feedforward=args.dim_feedforward,
|
| 440 |
+
dropout=args.dropout,
|
| 441 |
+
activation="relu",
|
| 442 |
+
return_intermediate_dec=True,
|
| 443 |
+
num_feature_levels=args.num_feature_levels,
|
| 444 |
+
dec_n_points=args.dec_n_points,
|
| 445 |
+
enc_n_points=args.enc_n_points,
|
| 446 |
+
two_stage=args.two_stage,
|
| 447 |
+
two_stage_num_proposals=args.num_queries,
|
| 448 |
+
assign_first_stage=args.assign_first_stage,
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
|
perception_models/apps/detection/DETA_pe/models/matcher.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
# Modified from DETR (https://github.com/facebookresearch/detr)
|
| 7 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 8 |
+
# ------------------------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
Modules to compute the matching cost and solve the corresponding LSAP.
|
| 12 |
+
"""
|
| 13 |
+
import torch
|
| 14 |
+
from scipy.optimize import linear_sum_assignment
|
| 15 |
+
from torch import nn
|
| 16 |
+
|
| 17 |
+
from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class HungarianMatcher(nn.Module):
|
| 21 |
+
"""This class computes an assignment between the targets and the predictions of the network
|
| 22 |
+
|
| 23 |
+
For efficiency reasons, the targets don't include the no_object. Because of this, in general,
|
| 24 |
+
there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
|
| 25 |
+
while the others are un-matched (and thus treated as non-objects).
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(self,
|
| 29 |
+
cost_class: float = 1,
|
| 30 |
+
cost_bbox: float = 1,
|
| 31 |
+
cost_giou: float = 1):
|
| 32 |
+
"""Creates the matcher
|
| 33 |
+
|
| 34 |
+
Params:
|
| 35 |
+
cost_class: This is the relative weight of the classification error in the matching cost
|
| 36 |
+
cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
|
| 37 |
+
cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
|
| 38 |
+
"""
|
| 39 |
+
super().__init__()
|
| 40 |
+
self.cost_class = cost_class
|
| 41 |
+
self.cost_bbox = cost_bbox
|
| 42 |
+
self.cost_giou = cost_giou
|
| 43 |
+
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"
|
| 44 |
+
|
| 45 |
+
def forward(self, outputs, targets):
|
| 46 |
+
""" Performs the matching
|
| 47 |
+
|
| 48 |
+
Params:
|
| 49 |
+
outputs: This is a dict that contains at least these entries:
|
| 50 |
+
"pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
|
| 51 |
+
"pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
|
| 52 |
+
|
| 53 |
+
targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
|
| 54 |
+
"labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
|
| 55 |
+
objects in the target) containing the class labels
|
| 56 |
+
"boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
|
| 57 |
+
|
| 58 |
+
Returns:
|
| 59 |
+
A list of size batch_size, containing tuples of (index_i, index_j) where:
|
| 60 |
+
- index_i is the indices of the selected predictions (in order)
|
| 61 |
+
- index_j is the indices of the corresponding selected targets (in order)
|
| 62 |
+
For each batch element, it holds:
|
| 63 |
+
len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
|
| 64 |
+
"""
|
| 65 |
+
with torch.no_grad():
|
| 66 |
+
bs, num_queries = outputs["pred_logits"].shape[:2]
|
| 67 |
+
|
| 68 |
+
# We flatten to compute the cost matrices in a batch
|
| 69 |
+
out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()
|
| 70 |
+
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
|
| 71 |
+
|
| 72 |
+
# Also concat the target labels and boxes
|
| 73 |
+
tgt_ids = torch.cat([v["labels"] for v in targets])
|
| 74 |
+
tgt_bbox = torch.cat([v["boxes"] for v in targets])
|
| 75 |
+
|
| 76 |
+
# Compute the classification cost.
|
| 77 |
+
alpha = 0.25
|
| 78 |
+
gamma = 2.0
|
| 79 |
+
neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
|
| 80 |
+
pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
|
| 81 |
+
cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
|
| 82 |
+
|
| 83 |
+
# Compute the L1 cost between boxes
|
| 84 |
+
cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
|
| 85 |
+
|
| 86 |
+
# Compute the giou cost betwen boxes
|
| 87 |
+
cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox),
|
| 88 |
+
box_cxcywh_to_xyxy(tgt_bbox))
|
| 89 |
+
|
| 90 |
+
# Final cost matrix
|
| 91 |
+
C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
|
| 92 |
+
C = C.view(bs, num_queries, -1).cpu()
|
| 93 |
+
|
| 94 |
+
sizes = [len(v["boxes"]) for v in targets]
|
| 95 |
+
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
|
| 96 |
+
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def build_matcher(args):
|
| 100 |
+
return HungarianMatcher(cost_class=args.set_cost_class,
|
| 101 |
+
cost_bbox=args.set_cost_bbox,
|
| 102 |
+
cost_giou=args.set_cost_giou)
|
perception_models/apps/detection/DETA_pe/models/ops/functions/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------------------------------
|
| 6 |
+
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
| 7 |
+
# ------------------------------------------------------------------------------------------------
|
| 8 |
+
|
| 9 |
+
from .ms_deform_attn_func import ms_deform_attn_core_pytorch, MSDeformAttnFunction
|
perception_models/apps/detection/DETA_pe/models/ops/functions/ms_deform_attn_func.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------------------------------
|
| 6 |
+
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
| 7 |
+
# ------------------------------------------------------------------------------------------------
|
| 8 |
+
|
| 9 |
+
from __future__ import absolute_import, division, print_function
|
| 10 |
+
|
| 11 |
+
import MultiScaleDeformableAttention as MSDA
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from torch.autograd import Function
|
| 16 |
+
from torch.autograd.function import once_differentiable
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class MSDeformAttnFunction(Function):
|
| 20 |
+
@staticmethod
|
| 21 |
+
def forward(
|
| 22 |
+
ctx,
|
| 23 |
+
value,
|
| 24 |
+
value_spatial_shapes,
|
| 25 |
+
value_level_start_index,
|
| 26 |
+
sampling_locations,
|
| 27 |
+
attention_weights,
|
| 28 |
+
im2col_step,
|
| 29 |
+
):
|
| 30 |
+
ctx.im2col_step = im2col_step
|
| 31 |
+
output = MSDA.ms_deform_attn_forward(
|
| 32 |
+
value,
|
| 33 |
+
value_spatial_shapes,
|
| 34 |
+
value_level_start_index,
|
| 35 |
+
sampling_locations,
|
| 36 |
+
attention_weights,
|
| 37 |
+
ctx.im2col_step,
|
| 38 |
+
)
|
| 39 |
+
ctx.save_for_backward(
|
| 40 |
+
value,
|
| 41 |
+
value_spatial_shapes,
|
| 42 |
+
value_level_start_index,
|
| 43 |
+
sampling_locations,
|
| 44 |
+
attention_weights,
|
| 45 |
+
)
|
| 46 |
+
return output
|
| 47 |
+
|
| 48 |
+
@staticmethod
|
| 49 |
+
@once_differentiable
|
| 50 |
+
def backward(ctx, grad_output):
|
| 51 |
+
(
|
| 52 |
+
value,
|
| 53 |
+
value_spatial_shapes,
|
| 54 |
+
value_level_start_index,
|
| 55 |
+
sampling_locations,
|
| 56 |
+
attention_weights,
|
| 57 |
+
) = ctx.saved_tensors
|
| 58 |
+
grad_value, grad_sampling_loc, grad_attn_weight = MSDA.ms_deform_attn_backward(
|
| 59 |
+
value,
|
| 60 |
+
value_spatial_shapes,
|
| 61 |
+
value_level_start_index,
|
| 62 |
+
sampling_locations,
|
| 63 |
+
attention_weights,
|
| 64 |
+
grad_output,
|
| 65 |
+
ctx.im2col_step,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def ms_deform_attn_core_pytorch(
|
| 72 |
+
value, value_spatial_shapes, sampling_locations, attention_weights
|
| 73 |
+
):
|
| 74 |
+
# for debug and test only,
|
| 75 |
+
# need to use cuda version instead
|
| 76 |
+
N_, S_, M_, D_ = value.shape
|
| 77 |
+
_, Lq_, M_, L_, P_, _ = sampling_locations.shape
|
| 78 |
+
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
|
| 79 |
+
sampling_grids = 2 * sampling_locations - 1
|
| 80 |
+
sampling_value_list = []
|
| 81 |
+
for lid_, (H_, W_) in enumerate(value_spatial_shapes):
|
| 82 |
+
# N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
|
| 83 |
+
value_l_ = (
|
| 84 |
+
value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_)
|
| 85 |
+
)
|
| 86 |
+
# N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
|
| 87 |
+
sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
|
| 88 |
+
# N_*M_, D_, Lq_, P_
|
| 89 |
+
sampling_value_l_ = F.grid_sample(
|
| 90 |
+
value_l_,
|
| 91 |
+
sampling_grid_l_,
|
| 92 |
+
mode="bilinear",
|
| 93 |
+
padding_mode="zeros",
|
| 94 |
+
align_corners=False,
|
| 95 |
+
)
|
| 96 |
+
sampling_value_list.append(sampling_value_l_)
|
| 97 |
+
# (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
|
| 98 |
+
attention_weights = attention_weights.transpose(1, 2).reshape(
|
| 99 |
+
N_ * M_, 1, Lq_, L_ * P_
|
| 100 |
+
)
|
| 101 |
+
output = (
|
| 102 |
+
(torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
|
| 103 |
+
.sum(-1)
|
| 104 |
+
.view(N_, M_ * D_, Lq_)
|
| 105 |
+
)
|
| 106 |
+
return output.transpose(1, 2).contiguous()
|
perception_models/apps/detection/DETA_pe/models/ops/make.sh
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# ------------------------------------------------------------------------------------------------
|
| 3 |
+
# Deformable DETR
|
| 4 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 6 |
+
# ------------------------------------------------------------------------------------------------
|
| 7 |
+
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
| 8 |
+
# ------------------------------------------------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
python setup.py build install
|
perception_models/apps/detection/DETA_pe/models/ops/modules/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------------------------------
|
| 6 |
+
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
| 7 |
+
# ------------------------------------------------------------------------------------------------
|
| 8 |
+
|
| 9 |
+
from .ms_deform_attn import MSDeformAttn
|
perception_models/apps/detection/DETA_pe/models/ops/modules/ms_deform_attn.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------------------------------
|
| 6 |
+
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
| 7 |
+
# ------------------------------------------------------------------------------------------------
|
| 8 |
+
|
| 9 |
+
from __future__ import absolute_import, division, print_function
|
| 10 |
+
|
| 11 |
+
import math
|
| 12 |
+
|
| 13 |
+
import warnings
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from torch import nn
|
| 18 |
+
from torch.nn.init import constant_, xavier_uniform_
|
| 19 |
+
|
| 20 |
+
from ..functions import ms_deform_attn_core_pytorch, MSDeformAttnFunction
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _is_power_of_2(n):
|
| 24 |
+
if (not isinstance(n, int)) or (n < 0):
|
| 25 |
+
raise ValueError(
|
| 26 |
+
"invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))
|
| 27 |
+
)
|
| 28 |
+
return (n & (n - 1) == 0) and n != 0
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class MSDeformAttn(nn.Module):
|
| 32 |
+
def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
|
| 33 |
+
"""
|
| 34 |
+
Multi-Scale Deformable Attention Module
|
| 35 |
+
:param d_model hidden dimension
|
| 36 |
+
:param n_levels number of feature levels
|
| 37 |
+
:param n_heads number of attention heads
|
| 38 |
+
:param n_points number of sampling points per attention head per feature level
|
| 39 |
+
"""
|
| 40 |
+
super().__init__()
|
| 41 |
+
if d_model % n_heads != 0:
|
| 42 |
+
raise ValueError(
|
| 43 |
+
"d_model must be divisible by n_heads, but got {} and {}".format(
|
| 44 |
+
d_model, n_heads
|
| 45 |
+
)
|
| 46 |
+
)
|
| 47 |
+
_d_per_head = d_model // n_heads
|
| 48 |
+
# you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
|
| 49 |
+
if not _is_power_of_2(_d_per_head):
|
| 50 |
+
warnings.warn(
|
| 51 |
+
"You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
|
| 52 |
+
"which is more efficient in our CUDA implementation."
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
self.im2col_step = 64
|
| 56 |
+
|
| 57 |
+
self.d_model = d_model
|
| 58 |
+
self.n_levels = n_levels
|
| 59 |
+
self.n_heads = n_heads
|
| 60 |
+
self.n_points = n_points
|
| 61 |
+
|
| 62 |
+
self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
|
| 63 |
+
self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
|
| 64 |
+
self.value_proj = nn.Linear(d_model, d_model)
|
| 65 |
+
self.output_proj = nn.Linear(d_model, d_model)
|
| 66 |
+
|
| 67 |
+
self._reset_parameters()
|
| 68 |
+
|
| 69 |
+
def _reset_parameters(self):
|
| 70 |
+
constant_(self.sampling_offsets.weight.data, 0.0)
|
| 71 |
+
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (
|
| 72 |
+
2.0 * math.pi / self.n_heads
|
| 73 |
+
)
|
| 74 |
+
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
|
| 75 |
+
grid_init = (
|
| 76 |
+
(grid_init / grid_init.abs().max(-1, keepdim=True)[0])
|
| 77 |
+
.view(self.n_heads, 1, 1, 2)
|
| 78 |
+
.repeat(1, self.n_levels, self.n_points, 1)
|
| 79 |
+
)
|
| 80 |
+
for i in range(self.n_points):
|
| 81 |
+
grid_init[:, :, i, :] *= i + 1
|
| 82 |
+
with torch.no_grad():
|
| 83 |
+
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
|
| 84 |
+
constant_(self.attention_weights.weight.data, 0.0)
|
| 85 |
+
constant_(self.attention_weights.bias.data, 0.0)
|
| 86 |
+
xavier_uniform_(self.value_proj.weight.data)
|
| 87 |
+
constant_(self.value_proj.bias.data, 0.0)
|
| 88 |
+
xavier_uniform_(self.output_proj.weight.data)
|
| 89 |
+
constant_(self.output_proj.bias.data, 0.0)
|
| 90 |
+
|
| 91 |
+
def forward(
|
| 92 |
+
self,
|
| 93 |
+
query,
|
| 94 |
+
reference_points,
|
| 95 |
+
input_flatten,
|
| 96 |
+
input_spatial_shapes,
|
| 97 |
+
input_level_start_index,
|
| 98 |
+
input_padding_mask=None,
|
| 99 |
+
):
|
| 100 |
+
"""
|
| 101 |
+
:param query (N, Length_{query}, C)
|
| 102 |
+
:param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
|
| 103 |
+
or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
|
| 104 |
+
:param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
|
| 105 |
+
:param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
|
| 106 |
+
:param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
|
| 107 |
+
:param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
|
| 108 |
+
|
| 109 |
+
:return output (N, Length_{query}, C)
|
| 110 |
+
"""
|
| 111 |
+
N, Len_q, _ = query.shape
|
| 112 |
+
N, Len_in, _ = input_flatten.shape
|
| 113 |
+
assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
|
| 114 |
+
|
| 115 |
+
value = self.value_proj(input_flatten)
|
| 116 |
+
if input_padding_mask is not None:
|
| 117 |
+
value = value.masked_fill(input_padding_mask[..., None], float(0))
|
| 118 |
+
value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
|
| 119 |
+
sampling_offsets = self.sampling_offsets(query).view(
|
| 120 |
+
N, Len_q, self.n_heads, self.n_levels, self.n_points, 2
|
| 121 |
+
)
|
| 122 |
+
attention_weights = self.attention_weights(query).view(
|
| 123 |
+
N, Len_q, self.n_heads, self.n_levels * self.n_points
|
| 124 |
+
)
|
| 125 |
+
attention_weights = F.softmax(attention_weights, -1).view(
|
| 126 |
+
N, Len_q, self.n_heads, self.n_levels, self.n_points
|
| 127 |
+
)
|
| 128 |
+
# N, Len_q, n_heads, n_levels, n_points, 2
|
| 129 |
+
if reference_points.shape[-1] == 2:
|
| 130 |
+
offset_normalizer = torch.stack(
|
| 131 |
+
[input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1
|
| 132 |
+
)
|
| 133 |
+
sampling_locations = (
|
| 134 |
+
reference_points[:, :, None, :, None, :]
|
| 135 |
+
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
|
| 136 |
+
)
|
| 137 |
+
elif reference_points.shape[-1] == 4:
|
| 138 |
+
sampling_locations = (
|
| 139 |
+
reference_points[:, :, None, :, None, :2]
|
| 140 |
+
+ sampling_offsets
|
| 141 |
+
/ self.n_points
|
| 142 |
+
* reference_points[:, :, None, :, None, 2:]
|
| 143 |
+
* 0.5
|
| 144 |
+
)
|
| 145 |
+
else:
|
| 146 |
+
raise ValueError(
|
| 147 |
+
"Last dim of reference_points must be 2 or 4, but get {} instead.".format(
|
| 148 |
+
reference_points.shape[-1]
|
| 149 |
+
)
|
| 150 |
+
)
|
| 151 |
+
output = MSDeformAttnFunction.apply(
|
| 152 |
+
value,
|
| 153 |
+
input_spatial_shapes,
|
| 154 |
+
input_level_start_index,
|
| 155 |
+
sampling_locations,
|
| 156 |
+
attention_weights,
|
| 157 |
+
self.im2col_step,
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
output = self.output_proj(output)
|
| 161 |
+
return output
|
perception_models/apps/detection/DETA_pe/models/ops/setup.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------------------------------
|
| 6 |
+
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
| 7 |
+
# ------------------------------------------------------------------------------------------------
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import glob
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
from torch.utils.cpp_extension import CUDA_HOME
|
| 15 |
+
from torch.utils.cpp_extension import CppExtension
|
| 16 |
+
from torch.utils.cpp_extension import CUDAExtension
|
| 17 |
+
|
| 18 |
+
from setuptools import find_packages
|
| 19 |
+
from setuptools import setup
|
| 20 |
+
|
| 21 |
+
requirements = ["torch", "torchvision"]
|
| 22 |
+
|
| 23 |
+
def get_extensions():
|
| 24 |
+
this_dir = os.path.dirname(os.path.abspath(__file__))
|
| 25 |
+
extensions_dir = os.path.join(this_dir, "src")
|
| 26 |
+
|
| 27 |
+
main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
|
| 28 |
+
source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
|
| 29 |
+
source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
|
| 30 |
+
|
| 31 |
+
sources = main_file + source_cpu
|
| 32 |
+
extension = CppExtension
|
| 33 |
+
extra_compile_args = {"cxx": []}
|
| 34 |
+
define_macros = []
|
| 35 |
+
|
| 36 |
+
if torch.cuda.is_available() and CUDA_HOME is not None:
|
| 37 |
+
extension = CUDAExtension
|
| 38 |
+
sources += source_cuda
|
| 39 |
+
define_macros += [("WITH_CUDA", None)]
|
| 40 |
+
extra_compile_args["nvcc"] = [
|
| 41 |
+
"-DCUDA_HAS_FP16=1",
|
| 42 |
+
"-D__CUDA_NO_HALF_OPERATORS__",
|
| 43 |
+
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
| 44 |
+
"-D__CUDA_NO_HALF2_OPERATORS__",
|
| 45 |
+
]
|
| 46 |
+
else:
|
| 47 |
+
raise NotImplementedError('Cuda is not availabel')
|
| 48 |
+
|
| 49 |
+
sources = [os.path.join(extensions_dir, s) for s in sources]
|
| 50 |
+
include_dirs = [extensions_dir]
|
| 51 |
+
ext_modules = [
|
| 52 |
+
extension(
|
| 53 |
+
"MultiScaleDeformableAttention",
|
| 54 |
+
sources,
|
| 55 |
+
include_dirs=include_dirs,
|
| 56 |
+
define_macros=define_macros,
|
| 57 |
+
extra_compile_args=extra_compile_args,
|
| 58 |
+
)
|
| 59 |
+
]
|
| 60 |
+
return ext_modules
|
| 61 |
+
|
| 62 |
+
setup(
|
| 63 |
+
name="MultiScaleDeformableAttention",
|
| 64 |
+
version="1.0",
|
| 65 |
+
author="Weijie Su",
|
| 66 |
+
url="https://github.com/fundamentalvision/Deformable-DETR",
|
| 67 |
+
description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
|
| 68 |
+
packages=find_packages(exclude=("configs", "tests",)),
|
| 69 |
+
ext_modules=get_extensions(),
|
| 70 |
+
cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
|
| 71 |
+
)
|
perception_models/apps/detection/DETA_pe/models/ops/src/cpu/ms_deform_attn_cpu.cpp
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*!
|
| 2 |
+
**************************************************************************************************
|
| 3 |
+
* Deformable DETR
|
| 4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 6 |
+
**************************************************************************************************
|
| 7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
| 8 |
+
**************************************************************************************************
|
| 9 |
+
*/
|
| 10 |
+
|
| 11 |
+
#include <vector>
|
| 12 |
+
|
| 13 |
+
#include <ATen/ATen.h>
|
| 14 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
at::Tensor
|
| 18 |
+
ms_deform_attn_cpu_forward(
|
| 19 |
+
const at::Tensor &value,
|
| 20 |
+
const at::Tensor &spatial_shapes,
|
| 21 |
+
const at::Tensor &level_start_index,
|
| 22 |
+
const at::Tensor &sampling_loc,
|
| 23 |
+
const at::Tensor &attn_weight,
|
| 24 |
+
const int im2col_step)
|
| 25 |
+
{
|
| 26 |
+
AT_ERROR("Not implement on cpu");
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
std::vector<at::Tensor>
|
| 30 |
+
ms_deform_attn_cpu_backward(
|
| 31 |
+
const at::Tensor &value,
|
| 32 |
+
const at::Tensor &spatial_shapes,
|
| 33 |
+
const at::Tensor &level_start_index,
|
| 34 |
+
const at::Tensor &sampling_loc,
|
| 35 |
+
const at::Tensor &attn_weight,
|
| 36 |
+
const at::Tensor &grad_output,
|
| 37 |
+
const int im2col_step)
|
| 38 |
+
{
|
| 39 |
+
AT_ERROR("Not implement on cpu");
|
| 40 |
+
}
|
| 41 |
+
|
perception_models/apps/detection/DETA_pe/models/ops/src/cpu/ms_deform_attn_cpu.h
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*!
|
| 2 |
+
**************************************************************************************************
|
| 3 |
+
* Deformable DETR
|
| 4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 6 |
+
**************************************************************************************************
|
| 7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
| 8 |
+
**************************************************************************************************
|
| 9 |
+
*/
|
| 10 |
+
|
| 11 |
+
#pragma once
|
| 12 |
+
#include <torch/extension.h>
|
| 13 |
+
|
| 14 |
+
at::Tensor
|
| 15 |
+
ms_deform_attn_cpu_forward(
|
| 16 |
+
const at::Tensor &value,
|
| 17 |
+
const at::Tensor &spatial_shapes,
|
| 18 |
+
const at::Tensor &level_start_index,
|
| 19 |
+
const at::Tensor &sampling_loc,
|
| 20 |
+
const at::Tensor &attn_weight,
|
| 21 |
+
const int im2col_step);
|
| 22 |
+
|
| 23 |
+
std::vector<at::Tensor>
|
| 24 |
+
ms_deform_attn_cpu_backward(
|
| 25 |
+
const at::Tensor &value,
|
| 26 |
+
const at::Tensor &spatial_shapes,
|
| 27 |
+
const at::Tensor &level_start_index,
|
| 28 |
+
const at::Tensor &sampling_loc,
|
| 29 |
+
const at::Tensor &attn_weight,
|
| 30 |
+
const at::Tensor &grad_output,
|
| 31 |
+
const int im2col_step);
|
| 32 |
+
|
| 33 |
+
|
perception_models/apps/detection/DETA_pe/models/ops/src/cuda/ms_deform_attn_cuda.cu
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*!
|
| 2 |
+
**************************************************************************************************
|
| 3 |
+
* Deformable DETR
|
| 4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 6 |
+
**************************************************************************************************
|
| 7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
| 8 |
+
**************************************************************************************************
|
| 9 |
+
*/
|
| 10 |
+
|
| 11 |
+
#include <vector>
|
| 12 |
+
#include "cuda/ms_deform_im2col_cuda.cuh"
|
| 13 |
+
|
| 14 |
+
#include <ATen/ATen.h>
|
| 15 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 16 |
+
#include <cuda.h>
|
| 17 |
+
#include <cuda_runtime.h>
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
at::Tensor ms_deform_attn_cuda_forward(
|
| 21 |
+
const at::Tensor &value,
|
| 22 |
+
const at::Tensor &spatial_shapes,
|
| 23 |
+
const at::Tensor &level_start_index,
|
| 24 |
+
const at::Tensor &sampling_loc,
|
| 25 |
+
const at::Tensor &attn_weight,
|
| 26 |
+
const int im2col_step)
|
| 27 |
+
{
|
| 28 |
+
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
|
| 29 |
+
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
|
| 30 |
+
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
|
| 31 |
+
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
|
| 32 |
+
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
|
| 33 |
+
|
| 34 |
+
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
|
| 35 |
+
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
|
| 36 |
+
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
|
| 37 |
+
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
|
| 38 |
+
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
|
| 39 |
+
|
| 40 |
+
const int batch = value.size(0);
|
| 41 |
+
const int spatial_size = value.size(1);
|
| 42 |
+
const int num_heads = value.size(2);
|
| 43 |
+
const int channels = value.size(3);
|
| 44 |
+
|
| 45 |
+
const int num_levels = spatial_shapes.size(0);
|
| 46 |
+
|
| 47 |
+
const int num_query = sampling_loc.size(1);
|
| 48 |
+
const int num_point = sampling_loc.size(4);
|
| 49 |
+
|
| 50 |
+
const int im2col_step_ = std::min(batch, im2col_step);
|
| 51 |
+
|
| 52 |
+
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
|
| 53 |
+
|
| 54 |
+
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
|
| 55 |
+
|
| 56 |
+
const int batch_n = im2col_step_;
|
| 57 |
+
auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
|
| 58 |
+
auto per_value_size = spatial_size * num_heads * channels;
|
| 59 |
+
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
|
| 60 |
+
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
|
| 61 |
+
for (int n = 0; n < batch/im2col_step_; ++n)
|
| 62 |
+
{
|
| 63 |
+
auto columns = output_n.select(0, n);
|
| 64 |
+
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
| 65 |
+
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
| 66 |
+
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
| 67 |
+
spatial_shapes.data<int64_t>(),
|
| 68 |
+
level_start_index.data<int64_t>(),
|
| 69 |
+
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
|
| 70 |
+
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
|
| 71 |
+
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
|
| 72 |
+
columns.data<scalar_t>());
|
| 73 |
+
|
| 74 |
+
}));
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
output = output.view({batch, num_query, num_heads*channels});
|
| 78 |
+
|
| 79 |
+
return output;
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
| 84 |
+
const at::Tensor &value,
|
| 85 |
+
const at::Tensor &spatial_shapes,
|
| 86 |
+
const at::Tensor &level_start_index,
|
| 87 |
+
const at::Tensor &sampling_loc,
|
| 88 |
+
const at::Tensor &attn_weight,
|
| 89 |
+
const at::Tensor &grad_output,
|
| 90 |
+
const int im2col_step)
|
| 91 |
+
{
|
| 92 |
+
|
| 93 |
+
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
|
| 94 |
+
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
|
| 95 |
+
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
|
| 96 |
+
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
|
| 97 |
+
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
|
| 98 |
+
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
|
| 99 |
+
|
| 100 |
+
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
|
| 101 |
+
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
|
| 102 |
+
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
|
| 103 |
+
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
|
| 104 |
+
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
|
| 105 |
+
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
|
| 106 |
+
|
| 107 |
+
const int batch = value.size(0);
|
| 108 |
+
const int spatial_size = value.size(1);
|
| 109 |
+
const int num_heads = value.size(2);
|
| 110 |
+
const int channels = value.size(3);
|
| 111 |
+
|
| 112 |
+
const int num_levels = spatial_shapes.size(0);
|
| 113 |
+
|
| 114 |
+
const int num_query = sampling_loc.size(1);
|
| 115 |
+
const int num_point = sampling_loc.size(4);
|
| 116 |
+
|
| 117 |
+
const int im2col_step_ = std::min(batch, im2col_step);
|
| 118 |
+
|
| 119 |
+
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
|
| 120 |
+
|
| 121 |
+
auto grad_value = at::zeros_like(value);
|
| 122 |
+
auto grad_sampling_loc = at::zeros_like(sampling_loc);
|
| 123 |
+
auto grad_attn_weight = at::zeros_like(attn_weight);
|
| 124 |
+
|
| 125 |
+
const int batch_n = im2col_step_;
|
| 126 |
+
auto per_value_size = spatial_size * num_heads * channels;
|
| 127 |
+
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
|
| 128 |
+
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
|
| 129 |
+
auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
|
| 130 |
+
|
| 131 |
+
for (int n = 0; n < batch/im2col_step_; ++n)
|
| 132 |
+
{
|
| 133 |
+
auto grad_output_g = grad_output_n.select(0, n);
|
| 134 |
+
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
| 135 |
+
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
| 136 |
+
grad_output_g.data<scalar_t>(),
|
| 137 |
+
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
| 138 |
+
spatial_shapes.data<int64_t>(),
|
| 139 |
+
level_start_index.data<int64_t>(),
|
| 140 |
+
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
|
| 141 |
+
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
|
| 142 |
+
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
|
| 143 |
+
grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
| 144 |
+
grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
|
| 145 |
+
grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
|
| 146 |
+
|
| 147 |
+
}));
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
return {
|
| 151 |
+
grad_value, grad_sampling_loc, grad_attn_weight
|
| 152 |
+
};
|
| 153 |
+
}
|
perception_models/apps/detection/DETA_pe/models/ops/src/cuda/ms_deform_attn_cuda.h
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*!
|
| 2 |
+
**************************************************************************************************
|
| 3 |
+
* Deformable DETR
|
| 4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 6 |
+
**************************************************************************************************
|
| 7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
| 8 |
+
**************************************************************************************************
|
| 9 |
+
*/
|
| 10 |
+
|
| 11 |
+
#pragma once
|
| 12 |
+
#include <torch/extension.h>
|
| 13 |
+
|
| 14 |
+
at::Tensor ms_deform_attn_cuda_forward(
|
| 15 |
+
const at::Tensor &value,
|
| 16 |
+
const at::Tensor &spatial_shapes,
|
| 17 |
+
const at::Tensor &level_start_index,
|
| 18 |
+
const at::Tensor &sampling_loc,
|
| 19 |
+
const at::Tensor &attn_weight,
|
| 20 |
+
const int im2col_step);
|
| 21 |
+
|
| 22 |
+
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
| 23 |
+
const at::Tensor &value,
|
| 24 |
+
const at::Tensor &spatial_shapes,
|
| 25 |
+
const at::Tensor &level_start_index,
|
| 26 |
+
const at::Tensor &sampling_loc,
|
| 27 |
+
const at::Tensor &attn_weight,
|
| 28 |
+
const at::Tensor &grad_output,
|
| 29 |
+
const int im2col_step);
|
| 30 |
+
|
perception_models/apps/detection/DETA_pe/models/ops/src/cuda/ms_deform_im2col_cuda.cuh
ADDED
|
@@ -0,0 +1,1327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*!
|
| 2 |
+
**************************************************************************
|
| 3 |
+
* Deformable DETR
|
| 4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 6 |
+
**************************************************************************
|
| 7 |
+
* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
|
| 8 |
+
* Copyright (c) 2018 Microsoft
|
| 9 |
+
**************************************************************************
|
| 10 |
+
*/
|
| 11 |
+
|
| 12 |
+
#include <cstdio>
|
| 13 |
+
#include <algorithm>
|
| 14 |
+
#include <cstring>
|
| 15 |
+
|
| 16 |
+
#include <ATen/ATen.h>
|
| 17 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 18 |
+
|
| 19 |
+
#include <THC/THCAtomics.cuh>
|
| 20 |
+
|
| 21 |
+
#define CUDA_KERNEL_LOOP(i, n) \
|
| 22 |
+
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
|
| 23 |
+
i < (n); \
|
| 24 |
+
i += blockDim.x * gridDim.x)
|
| 25 |
+
|
| 26 |
+
const int CUDA_NUM_THREADS = 1024;
|
| 27 |
+
inline int GET_BLOCKS(const int N, const int num_threads)
|
| 28 |
+
{
|
| 29 |
+
return (N + num_threads - 1) / num_threads;
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
template <typename scalar_t>
|
| 34 |
+
__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
|
| 35 |
+
const int &height, const int &width, const int &nheads, const int &channels,
|
| 36 |
+
const scalar_t &h, const scalar_t &w, const int &m, const int &c)
|
| 37 |
+
{
|
| 38 |
+
const int h_low = floor(h);
|
| 39 |
+
const int w_low = floor(w);
|
| 40 |
+
const int h_high = h_low + 1;
|
| 41 |
+
const int w_high = w_low + 1;
|
| 42 |
+
|
| 43 |
+
const scalar_t lh = h - h_low;
|
| 44 |
+
const scalar_t lw = w - w_low;
|
| 45 |
+
const scalar_t hh = 1 - lh, hw = 1 - lw;
|
| 46 |
+
|
| 47 |
+
const int w_stride = nheads * channels;
|
| 48 |
+
const int h_stride = width * w_stride;
|
| 49 |
+
const int h_low_ptr_offset = h_low * h_stride;
|
| 50 |
+
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
|
| 51 |
+
const int w_low_ptr_offset = w_low * w_stride;
|
| 52 |
+
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
|
| 53 |
+
const int base_ptr = m * channels + c;
|
| 54 |
+
|
| 55 |
+
scalar_t v1 = 0;
|
| 56 |
+
if (h_low >= 0 && w_low >= 0)
|
| 57 |
+
{
|
| 58 |
+
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
|
| 59 |
+
v1 = bottom_data[ptr1];
|
| 60 |
+
}
|
| 61 |
+
scalar_t v2 = 0;
|
| 62 |
+
if (h_low >= 0 && w_high <= width - 1)
|
| 63 |
+
{
|
| 64 |
+
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
|
| 65 |
+
v2 = bottom_data[ptr2];
|
| 66 |
+
}
|
| 67 |
+
scalar_t v3 = 0;
|
| 68 |
+
if (h_high <= height - 1 && w_low >= 0)
|
| 69 |
+
{
|
| 70 |
+
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
|
| 71 |
+
v3 = bottom_data[ptr3];
|
| 72 |
+
}
|
| 73 |
+
scalar_t v4 = 0;
|
| 74 |
+
if (h_high <= height - 1 && w_high <= width - 1)
|
| 75 |
+
{
|
| 76 |
+
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
|
| 77 |
+
v4 = bottom_data[ptr4];
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
|
| 81 |
+
|
| 82 |
+
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
| 83 |
+
return val;
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
template <typename scalar_t>
|
| 88 |
+
__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
|
| 89 |
+
const int &height, const int &width, const int &nheads, const int &channels,
|
| 90 |
+
const scalar_t &h, const scalar_t &w, const int &m, const int &c,
|
| 91 |
+
const scalar_t &top_grad,
|
| 92 |
+
const scalar_t &attn_weight,
|
| 93 |
+
scalar_t* &grad_value,
|
| 94 |
+
scalar_t* grad_sampling_loc,
|
| 95 |
+
scalar_t* grad_attn_weight)
|
| 96 |
+
{
|
| 97 |
+
const int h_low = floor(h);
|
| 98 |
+
const int w_low = floor(w);
|
| 99 |
+
const int h_high = h_low + 1;
|
| 100 |
+
const int w_high = w_low + 1;
|
| 101 |
+
|
| 102 |
+
const scalar_t lh = h - h_low;
|
| 103 |
+
const scalar_t lw = w - w_low;
|
| 104 |
+
const scalar_t hh = 1 - lh, hw = 1 - lw;
|
| 105 |
+
|
| 106 |
+
const int w_stride = nheads * channels;
|
| 107 |
+
const int h_stride = width * w_stride;
|
| 108 |
+
const int h_low_ptr_offset = h_low * h_stride;
|
| 109 |
+
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
|
| 110 |
+
const int w_low_ptr_offset = w_low * w_stride;
|
| 111 |
+
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
|
| 112 |
+
const int base_ptr = m * channels + c;
|
| 113 |
+
|
| 114 |
+
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
|
| 115 |
+
const scalar_t top_grad_value = top_grad * attn_weight;
|
| 116 |
+
scalar_t grad_h_weight = 0, grad_w_weight = 0;
|
| 117 |
+
|
| 118 |
+
scalar_t v1 = 0;
|
| 119 |
+
if (h_low >= 0 && w_low >= 0)
|
| 120 |
+
{
|
| 121 |
+
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
|
| 122 |
+
v1 = bottom_data[ptr1];
|
| 123 |
+
grad_h_weight -= hw * v1;
|
| 124 |
+
grad_w_weight -= hh * v1;
|
| 125 |
+
atomicAdd(grad_value+ptr1, w1*top_grad_value);
|
| 126 |
+
}
|
| 127 |
+
scalar_t v2 = 0;
|
| 128 |
+
if (h_low >= 0 && w_high <= width - 1)
|
| 129 |
+
{
|
| 130 |
+
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
|
| 131 |
+
v2 = bottom_data[ptr2];
|
| 132 |
+
grad_h_weight -= lw * v2;
|
| 133 |
+
grad_w_weight += hh * v2;
|
| 134 |
+
atomicAdd(grad_value+ptr2, w2*top_grad_value);
|
| 135 |
+
}
|
| 136 |
+
scalar_t v3 = 0;
|
| 137 |
+
if (h_high <= height - 1 && w_low >= 0)
|
| 138 |
+
{
|
| 139 |
+
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
|
| 140 |
+
v3 = bottom_data[ptr3];
|
| 141 |
+
grad_h_weight += hw * v3;
|
| 142 |
+
grad_w_weight -= lh * v3;
|
| 143 |
+
atomicAdd(grad_value+ptr3, w3*top_grad_value);
|
| 144 |
+
}
|
| 145 |
+
scalar_t v4 = 0;
|
| 146 |
+
if (h_high <= height - 1 && w_high <= width - 1)
|
| 147 |
+
{
|
| 148 |
+
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
|
| 149 |
+
v4 = bottom_data[ptr4];
|
| 150 |
+
grad_h_weight += lw * v4;
|
| 151 |
+
grad_w_weight += lh * v4;
|
| 152 |
+
atomicAdd(grad_value+ptr4, w4*top_grad_value);
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
| 156 |
+
*grad_attn_weight = top_grad * val;
|
| 157 |
+
*grad_sampling_loc = width * grad_w_weight * top_grad_value;
|
| 158 |
+
*(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
template <typename scalar_t>
|
| 163 |
+
__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
|
| 164 |
+
const int &height, const int &width, const int &nheads, const int &channels,
|
| 165 |
+
const scalar_t &h, const scalar_t &w, const int &m, const int &c,
|
| 166 |
+
const scalar_t &top_grad,
|
| 167 |
+
const scalar_t &attn_weight,
|
| 168 |
+
scalar_t* &grad_value,
|
| 169 |
+
scalar_t* grad_sampling_loc,
|
| 170 |
+
scalar_t* grad_attn_weight)
|
| 171 |
+
{
|
| 172 |
+
const int h_low = floor(h);
|
| 173 |
+
const int w_low = floor(w);
|
| 174 |
+
const int h_high = h_low + 1;
|
| 175 |
+
const int w_high = w_low + 1;
|
| 176 |
+
|
| 177 |
+
const scalar_t lh = h - h_low;
|
| 178 |
+
const scalar_t lw = w - w_low;
|
| 179 |
+
const scalar_t hh = 1 - lh, hw = 1 - lw;
|
| 180 |
+
|
| 181 |
+
const int w_stride = nheads * channels;
|
| 182 |
+
const int h_stride = width * w_stride;
|
| 183 |
+
const int h_low_ptr_offset = h_low * h_stride;
|
| 184 |
+
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
|
| 185 |
+
const int w_low_ptr_offset = w_low * w_stride;
|
| 186 |
+
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
|
| 187 |
+
const int base_ptr = m * channels + c;
|
| 188 |
+
|
| 189 |
+
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
|
| 190 |
+
const scalar_t top_grad_value = top_grad * attn_weight;
|
| 191 |
+
scalar_t grad_h_weight = 0, grad_w_weight = 0;
|
| 192 |
+
|
| 193 |
+
scalar_t v1 = 0;
|
| 194 |
+
if (h_low >= 0 && w_low >= 0)
|
| 195 |
+
{
|
| 196 |
+
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
|
| 197 |
+
v1 = bottom_data[ptr1];
|
| 198 |
+
grad_h_weight -= hw * v1;
|
| 199 |
+
grad_w_weight -= hh * v1;
|
| 200 |
+
atomicAdd(grad_value+ptr1, w1*top_grad_value);
|
| 201 |
+
}
|
| 202 |
+
scalar_t v2 = 0;
|
| 203 |
+
if (h_low >= 0 && w_high <= width - 1)
|
| 204 |
+
{
|
| 205 |
+
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
|
| 206 |
+
v2 = bottom_data[ptr2];
|
| 207 |
+
grad_h_weight -= lw * v2;
|
| 208 |
+
grad_w_weight += hh * v2;
|
| 209 |
+
atomicAdd(grad_value+ptr2, w2*top_grad_value);
|
| 210 |
+
}
|
| 211 |
+
scalar_t v3 = 0;
|
| 212 |
+
if (h_high <= height - 1 && w_low >= 0)
|
| 213 |
+
{
|
| 214 |
+
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
|
| 215 |
+
v3 = bottom_data[ptr3];
|
| 216 |
+
grad_h_weight += hw * v3;
|
| 217 |
+
grad_w_weight -= lh * v3;
|
| 218 |
+
atomicAdd(grad_value+ptr3, w3*top_grad_value);
|
| 219 |
+
}
|
| 220 |
+
scalar_t v4 = 0;
|
| 221 |
+
if (h_high <= height - 1 && w_high <= width - 1)
|
| 222 |
+
{
|
| 223 |
+
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
|
| 224 |
+
v4 = bottom_data[ptr4];
|
| 225 |
+
grad_h_weight += lw * v4;
|
| 226 |
+
grad_w_weight += lh * v4;
|
| 227 |
+
atomicAdd(grad_value+ptr4, w4*top_grad_value);
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
| 231 |
+
atomicAdd(grad_attn_weight, top_grad * val);
|
| 232 |
+
atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
|
| 233 |
+
atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
template <typename scalar_t>
|
| 238 |
+
__global__ void ms_deformable_im2col_gpu_kernel(const int n,
|
| 239 |
+
const scalar_t *data_value,
|
| 240 |
+
const int64_t *data_spatial_shapes,
|
| 241 |
+
const int64_t *data_level_start_index,
|
| 242 |
+
const scalar_t *data_sampling_loc,
|
| 243 |
+
const scalar_t *data_attn_weight,
|
| 244 |
+
const int batch_size,
|
| 245 |
+
const int spatial_size,
|
| 246 |
+
const int num_heads,
|
| 247 |
+
const int channels,
|
| 248 |
+
const int num_levels,
|
| 249 |
+
const int num_query,
|
| 250 |
+
const int num_point,
|
| 251 |
+
scalar_t *data_col)
|
| 252 |
+
{
|
| 253 |
+
CUDA_KERNEL_LOOP(index, n)
|
| 254 |
+
{
|
| 255 |
+
int _temp = index;
|
| 256 |
+
const int c_col = _temp % channels;
|
| 257 |
+
_temp /= channels;
|
| 258 |
+
const int sampling_index = _temp;
|
| 259 |
+
const int m_col = _temp % num_heads;
|
| 260 |
+
_temp /= num_heads;
|
| 261 |
+
const int q_col = _temp % num_query;
|
| 262 |
+
_temp /= num_query;
|
| 263 |
+
const int b_col = _temp;
|
| 264 |
+
|
| 265 |
+
scalar_t *data_col_ptr = data_col + index;
|
| 266 |
+
int data_weight_ptr = sampling_index * num_levels * num_point;
|
| 267 |
+
int data_loc_w_ptr = data_weight_ptr << 1;
|
| 268 |
+
const int qid_stride = num_heads * channels;
|
| 269 |
+
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
| 270 |
+
scalar_t col = 0;
|
| 271 |
+
|
| 272 |
+
for (int l_col=0; l_col < num_levels; ++l_col)
|
| 273 |
+
{
|
| 274 |
+
const int level_start_id = data_level_start_index[l_col];
|
| 275 |
+
const int spatial_h_ptr = l_col << 1;
|
| 276 |
+
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
| 277 |
+
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
| 278 |
+
const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
|
| 279 |
+
for (int p_col=0; p_col < num_point; ++p_col)
|
| 280 |
+
{
|
| 281 |
+
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
| 282 |
+
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
| 283 |
+
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
| 284 |
+
|
| 285 |
+
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
| 286 |
+
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
| 287 |
+
|
| 288 |
+
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
|
| 289 |
+
{
|
| 290 |
+
col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
data_weight_ptr += 1;
|
| 294 |
+
data_loc_w_ptr += 2;
|
| 295 |
+
}
|
| 296 |
+
}
|
| 297 |
+
*data_col_ptr = col;
|
| 298 |
+
}
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
template <typename scalar_t, unsigned int blockSize>
|
| 302 |
+
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
|
| 303 |
+
const scalar_t *grad_col,
|
| 304 |
+
const scalar_t *data_value,
|
| 305 |
+
const int64_t *data_spatial_shapes,
|
| 306 |
+
const int64_t *data_level_start_index,
|
| 307 |
+
const scalar_t *data_sampling_loc,
|
| 308 |
+
const scalar_t *data_attn_weight,
|
| 309 |
+
const int batch_size,
|
| 310 |
+
const int spatial_size,
|
| 311 |
+
const int num_heads,
|
| 312 |
+
const int channels,
|
| 313 |
+
const int num_levels,
|
| 314 |
+
const int num_query,
|
| 315 |
+
const int num_point,
|
| 316 |
+
scalar_t *grad_value,
|
| 317 |
+
scalar_t *grad_sampling_loc,
|
| 318 |
+
scalar_t *grad_attn_weight)
|
| 319 |
+
{
|
| 320 |
+
CUDA_KERNEL_LOOP(index, n)
|
| 321 |
+
{
|
| 322 |
+
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
|
| 323 |
+
__shared__ scalar_t cache_grad_attn_weight[blockSize];
|
| 324 |
+
unsigned int tid = threadIdx.x;
|
| 325 |
+
int _temp = index;
|
| 326 |
+
const int c_col = _temp % channels;
|
| 327 |
+
_temp /= channels;
|
| 328 |
+
const int sampling_index = _temp;
|
| 329 |
+
const int m_col = _temp % num_heads;
|
| 330 |
+
_temp /= num_heads;
|
| 331 |
+
const int q_col = _temp % num_query;
|
| 332 |
+
_temp /= num_query;
|
| 333 |
+
const int b_col = _temp;
|
| 334 |
+
|
| 335 |
+
const scalar_t top_grad = grad_col[index];
|
| 336 |
+
|
| 337 |
+
int data_weight_ptr = sampling_index * num_levels * num_point;
|
| 338 |
+
int data_loc_w_ptr = data_weight_ptr << 1;
|
| 339 |
+
const int grad_sampling_ptr = data_weight_ptr;
|
| 340 |
+
grad_sampling_loc += grad_sampling_ptr << 1;
|
| 341 |
+
grad_attn_weight += grad_sampling_ptr;
|
| 342 |
+
const int grad_weight_stride = 1;
|
| 343 |
+
const int grad_loc_stride = 2;
|
| 344 |
+
const int qid_stride = num_heads * channels;
|
| 345 |
+
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
| 346 |
+
|
| 347 |
+
for (int l_col=0; l_col < num_levels; ++l_col)
|
| 348 |
+
{
|
| 349 |
+
const int level_start_id = data_level_start_index[l_col];
|
| 350 |
+
const int spatial_h_ptr = l_col << 1;
|
| 351 |
+
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
| 352 |
+
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
| 353 |
+
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
|
| 354 |
+
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
|
| 355 |
+
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
|
| 356 |
+
|
| 357 |
+
for (int p_col=0; p_col < num_point; ++p_col)
|
| 358 |
+
{
|
| 359 |
+
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
| 360 |
+
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
| 361 |
+
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
| 362 |
+
|
| 363 |
+
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
| 364 |
+
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
| 365 |
+
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
|
| 366 |
+
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
|
| 367 |
+
*(cache_grad_attn_weight+threadIdx.x)=0;
|
| 368 |
+
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
|
| 369 |
+
{
|
| 370 |
+
ms_deform_attn_col2im_bilinear(
|
| 371 |
+
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
|
| 372 |
+
top_grad, weight, grad_value_ptr,
|
| 373 |
+
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
__syncthreads();
|
| 377 |
+
if (tid == 0)
|
| 378 |
+
{
|
| 379 |
+
scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
|
| 380 |
+
int sid=2;
|
| 381 |
+
for (unsigned int tid = 1; tid < blockSize; ++tid)
|
| 382 |
+
{
|
| 383 |
+
_grad_w += cache_grad_sampling_loc[sid];
|
| 384 |
+
_grad_h += cache_grad_sampling_loc[sid + 1];
|
| 385 |
+
_grad_a += cache_grad_attn_weight[tid];
|
| 386 |
+
sid += 2;
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
*grad_sampling_loc = _grad_w;
|
| 391 |
+
*(grad_sampling_loc + 1) = _grad_h;
|
| 392 |
+
*grad_attn_weight = _grad_a;
|
| 393 |
+
}
|
| 394 |
+
__syncthreads();
|
| 395 |
+
|
| 396 |
+
data_weight_ptr += 1;
|
| 397 |
+
data_loc_w_ptr += 2;
|
| 398 |
+
grad_attn_weight += grad_weight_stride;
|
| 399 |
+
grad_sampling_loc += grad_loc_stride;
|
| 400 |
+
}
|
| 401 |
+
}
|
| 402 |
+
}
|
| 403 |
+
}
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
template <typename scalar_t, unsigned int blockSize>
|
| 407 |
+
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
|
| 408 |
+
const scalar_t *grad_col,
|
| 409 |
+
const scalar_t *data_value,
|
| 410 |
+
const int64_t *data_spatial_shapes,
|
| 411 |
+
const int64_t *data_level_start_index,
|
| 412 |
+
const scalar_t *data_sampling_loc,
|
| 413 |
+
const scalar_t *data_attn_weight,
|
| 414 |
+
const int batch_size,
|
| 415 |
+
const int spatial_size,
|
| 416 |
+
const int num_heads,
|
| 417 |
+
const int channels,
|
| 418 |
+
const int num_levels,
|
| 419 |
+
const int num_query,
|
| 420 |
+
const int num_point,
|
| 421 |
+
scalar_t *grad_value,
|
| 422 |
+
scalar_t *grad_sampling_loc,
|
| 423 |
+
scalar_t *grad_attn_weight)
|
| 424 |
+
{
|
| 425 |
+
CUDA_KERNEL_LOOP(index, n)
|
| 426 |
+
{
|
| 427 |
+
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
|
| 428 |
+
__shared__ scalar_t cache_grad_attn_weight[blockSize];
|
| 429 |
+
unsigned int tid = threadIdx.x;
|
| 430 |
+
int _temp = index;
|
| 431 |
+
const int c_col = _temp % channels;
|
| 432 |
+
_temp /= channels;
|
| 433 |
+
const int sampling_index = _temp;
|
| 434 |
+
const int m_col = _temp % num_heads;
|
| 435 |
+
_temp /= num_heads;
|
| 436 |
+
const int q_col = _temp % num_query;
|
| 437 |
+
_temp /= num_query;
|
| 438 |
+
const int b_col = _temp;
|
| 439 |
+
|
| 440 |
+
const scalar_t top_grad = grad_col[index];
|
| 441 |
+
|
| 442 |
+
int data_weight_ptr = sampling_index * num_levels * num_point;
|
| 443 |
+
int data_loc_w_ptr = data_weight_ptr << 1;
|
| 444 |
+
const int grad_sampling_ptr = data_weight_ptr;
|
| 445 |
+
grad_sampling_loc += grad_sampling_ptr << 1;
|
| 446 |
+
grad_attn_weight += grad_sampling_ptr;
|
| 447 |
+
const int grad_weight_stride = 1;
|
| 448 |
+
const int grad_loc_stride = 2;
|
| 449 |
+
const int qid_stride = num_heads * channels;
|
| 450 |
+
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
| 451 |
+
|
| 452 |
+
for (int l_col=0; l_col < num_levels; ++l_col)
|
| 453 |
+
{
|
| 454 |
+
const int level_start_id = data_level_start_index[l_col];
|
| 455 |
+
const int spatial_h_ptr = l_col << 1;
|
| 456 |
+
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
| 457 |
+
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
| 458 |
+
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
|
| 459 |
+
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
|
| 460 |
+
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
|
| 461 |
+
|
| 462 |
+
for (int p_col=0; p_col < num_point; ++p_col)
|
| 463 |
+
{
|
| 464 |
+
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
| 465 |
+
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
| 466 |
+
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
| 467 |
+
|
| 468 |
+
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
| 469 |
+
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
| 470 |
+
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
|
| 471 |
+
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
|
| 472 |
+
*(cache_grad_attn_weight+threadIdx.x)=0;
|
| 473 |
+
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
|
| 474 |
+
{
|
| 475 |
+
ms_deform_attn_col2im_bilinear(
|
| 476 |
+
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
|
| 477 |
+
top_grad, weight, grad_value_ptr,
|
| 478 |
+
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
|
| 479 |
+
}
|
| 480 |
+
|
| 481 |
+
__syncthreads();
|
| 482 |
+
|
| 483 |
+
for (unsigned int s=blockSize/2; s>0; s>>=1)
|
| 484 |
+
{
|
| 485 |
+
if (tid < s) {
|
| 486 |
+
const unsigned int xid1 = tid << 1;
|
| 487 |
+
const unsigned int xid2 = (tid + s) << 1;
|
| 488 |
+
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
|
| 489 |
+
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
|
| 490 |
+
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
|
| 491 |
+
}
|
| 492 |
+
__syncthreads();
|
| 493 |
+
}
|
| 494 |
+
|
| 495 |
+
if (tid == 0)
|
| 496 |
+
{
|
| 497 |
+
*grad_sampling_loc = cache_grad_sampling_loc[0];
|
| 498 |
+
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
|
| 499 |
+
*grad_attn_weight = cache_grad_attn_weight[0];
|
| 500 |
+
}
|
| 501 |
+
__syncthreads();
|
| 502 |
+
|
| 503 |
+
data_weight_ptr += 1;
|
| 504 |
+
data_loc_w_ptr += 2;
|
| 505 |
+
grad_attn_weight += grad_weight_stride;
|
| 506 |
+
grad_sampling_loc += grad_loc_stride;
|
| 507 |
+
}
|
| 508 |
+
}
|
| 509 |
+
}
|
| 510 |
+
}
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
template <typename scalar_t>
|
| 514 |
+
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
|
| 515 |
+
const scalar_t *grad_col,
|
| 516 |
+
const scalar_t *data_value,
|
| 517 |
+
const int64_t *data_spatial_shapes,
|
| 518 |
+
const int64_t *data_level_start_index,
|
| 519 |
+
const scalar_t *data_sampling_loc,
|
| 520 |
+
const scalar_t *data_attn_weight,
|
| 521 |
+
const int batch_size,
|
| 522 |
+
const int spatial_size,
|
| 523 |
+
const int num_heads,
|
| 524 |
+
const int channels,
|
| 525 |
+
const int num_levels,
|
| 526 |
+
const int num_query,
|
| 527 |
+
const int num_point,
|
| 528 |
+
scalar_t *grad_value,
|
| 529 |
+
scalar_t *grad_sampling_loc,
|
| 530 |
+
scalar_t *grad_attn_weight)
|
| 531 |
+
{
|
| 532 |
+
CUDA_KERNEL_LOOP(index, n)
|
| 533 |
+
{
|
| 534 |
+
extern __shared__ int _s[];
|
| 535 |
+
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
|
| 536 |
+
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
|
| 537 |
+
unsigned int tid = threadIdx.x;
|
| 538 |
+
int _temp = index;
|
| 539 |
+
const int c_col = _temp % channels;
|
| 540 |
+
_temp /= channels;
|
| 541 |
+
const int sampling_index = _temp;
|
| 542 |
+
const int m_col = _temp % num_heads;
|
| 543 |
+
_temp /= num_heads;
|
| 544 |
+
const int q_col = _temp % num_query;
|
| 545 |
+
_temp /= num_query;
|
| 546 |
+
const int b_col = _temp;
|
| 547 |
+
|
| 548 |
+
const scalar_t top_grad = grad_col[index];
|
| 549 |
+
|
| 550 |
+
int data_weight_ptr = sampling_index * num_levels * num_point;
|
| 551 |
+
int data_loc_w_ptr = data_weight_ptr << 1;
|
| 552 |
+
const int grad_sampling_ptr = data_weight_ptr;
|
| 553 |
+
grad_sampling_loc += grad_sampling_ptr << 1;
|
| 554 |
+
grad_attn_weight += grad_sampling_ptr;
|
| 555 |
+
const int grad_weight_stride = 1;
|
| 556 |
+
const int grad_loc_stride = 2;
|
| 557 |
+
const int qid_stride = num_heads * channels;
|
| 558 |
+
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
| 559 |
+
|
| 560 |
+
for (int l_col=0; l_col < num_levels; ++l_col)
|
| 561 |
+
{
|
| 562 |
+
const int level_start_id = data_level_start_index[l_col];
|
| 563 |
+
const int spatial_h_ptr = l_col << 1;
|
| 564 |
+
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
| 565 |
+
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
| 566 |
+
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
|
| 567 |
+
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
|
| 568 |
+
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
|
| 569 |
+
|
| 570 |
+
for (int p_col=0; p_col < num_point; ++p_col)
|
| 571 |
+
{
|
| 572 |
+
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
| 573 |
+
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
| 574 |
+
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
| 575 |
+
|
| 576 |
+
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
| 577 |
+
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
| 578 |
+
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
|
| 579 |
+
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
|
| 580 |
+
*(cache_grad_attn_weight+threadIdx.x)=0;
|
| 581 |
+
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
|
| 582 |
+
{
|
| 583 |
+
ms_deform_attn_col2im_bilinear(
|
| 584 |
+
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
|
| 585 |
+
top_grad, weight, grad_value_ptr,
|
| 586 |
+
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
|
| 587 |
+
}
|
| 588 |
+
|
| 589 |
+
__syncthreads();
|
| 590 |
+
if (tid == 0)
|
| 591 |
+
{
|
| 592 |
+
scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
|
| 593 |
+
int sid=2;
|
| 594 |
+
for (unsigned int tid = 1; tid < blockDim.x; ++tid)
|
| 595 |
+
{
|
| 596 |
+
_grad_w += cache_grad_sampling_loc[sid];
|
| 597 |
+
_grad_h += cache_grad_sampling_loc[sid + 1];
|
| 598 |
+
_grad_a += cache_grad_attn_weight[tid];
|
| 599 |
+
sid += 2;
|
| 600 |
+
}
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
*grad_sampling_loc = _grad_w;
|
| 604 |
+
*(grad_sampling_loc + 1) = _grad_h;
|
| 605 |
+
*grad_attn_weight = _grad_a;
|
| 606 |
+
}
|
| 607 |
+
__syncthreads();
|
| 608 |
+
|
| 609 |
+
data_weight_ptr += 1;
|
| 610 |
+
data_loc_w_ptr += 2;
|
| 611 |
+
grad_attn_weight += grad_weight_stride;
|
| 612 |
+
grad_sampling_loc += grad_loc_stride;
|
| 613 |
+
}
|
| 614 |
+
}
|
| 615 |
+
}
|
| 616 |
+
}
|
| 617 |
+
|
| 618 |
+
template <typename scalar_t>
|
| 619 |
+
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
|
| 620 |
+
const scalar_t *grad_col,
|
| 621 |
+
const scalar_t *data_value,
|
| 622 |
+
const int64_t *data_spatial_shapes,
|
| 623 |
+
const int64_t *data_level_start_index,
|
| 624 |
+
const scalar_t *data_sampling_loc,
|
| 625 |
+
const scalar_t *data_attn_weight,
|
| 626 |
+
const int batch_size,
|
| 627 |
+
const int spatial_size,
|
| 628 |
+
const int num_heads,
|
| 629 |
+
const int channels,
|
| 630 |
+
const int num_levels,
|
| 631 |
+
const int num_query,
|
| 632 |
+
const int num_point,
|
| 633 |
+
scalar_t *grad_value,
|
| 634 |
+
scalar_t *grad_sampling_loc,
|
| 635 |
+
scalar_t *grad_attn_weight)
|
| 636 |
+
{
|
| 637 |
+
CUDA_KERNEL_LOOP(index, n)
|
| 638 |
+
{
|
| 639 |
+
extern __shared__ int _s[];
|
| 640 |
+
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
|
| 641 |
+
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
|
| 642 |
+
unsigned int tid = threadIdx.x;
|
| 643 |
+
int _temp = index;
|
| 644 |
+
const int c_col = _temp % channels;
|
| 645 |
+
_temp /= channels;
|
| 646 |
+
const int sampling_index = _temp;
|
| 647 |
+
const int m_col = _temp % num_heads;
|
| 648 |
+
_temp /= num_heads;
|
| 649 |
+
const int q_col = _temp % num_query;
|
| 650 |
+
_temp /= num_query;
|
| 651 |
+
const int b_col = _temp;
|
| 652 |
+
|
| 653 |
+
const scalar_t top_grad = grad_col[index];
|
| 654 |
+
|
| 655 |
+
int data_weight_ptr = sampling_index * num_levels * num_point;
|
| 656 |
+
int data_loc_w_ptr = data_weight_ptr << 1;
|
| 657 |
+
const int grad_sampling_ptr = data_weight_ptr;
|
| 658 |
+
grad_sampling_loc += grad_sampling_ptr << 1;
|
| 659 |
+
grad_attn_weight += grad_sampling_ptr;
|
| 660 |
+
const int grad_weight_stride = 1;
|
| 661 |
+
const int grad_loc_stride = 2;
|
| 662 |
+
const int qid_stride = num_heads * channels;
|
| 663 |
+
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
| 664 |
+
|
| 665 |
+
for (int l_col=0; l_col < num_levels; ++l_col)
|
| 666 |
+
{
|
| 667 |
+
const int level_start_id = data_level_start_index[l_col];
|
| 668 |
+
const int spatial_h_ptr = l_col << 1;
|
| 669 |
+
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
| 670 |
+
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
| 671 |
+
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
|
| 672 |
+
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
|
| 673 |
+
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
|
| 674 |
+
|
| 675 |
+
for (int p_col=0; p_col < num_point; ++p_col)
|
| 676 |
+
{
|
| 677 |
+
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
| 678 |
+
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
| 679 |
+
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
| 680 |
+
|
| 681 |
+
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
| 682 |
+
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
| 683 |
+
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
|
| 684 |
+
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
|
| 685 |
+
*(cache_grad_attn_weight+threadIdx.x)=0;
|
| 686 |
+
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
|
| 687 |
+
{
|
| 688 |
+
ms_deform_attn_col2im_bilinear(
|
| 689 |
+
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
|
| 690 |
+
top_grad, weight, grad_value_ptr,
|
| 691 |
+
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
|
| 692 |
+
}
|
| 693 |
+
|
| 694 |
+
__syncthreads();
|
| 695 |
+
|
| 696 |
+
for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
|
| 697 |
+
{
|
| 698 |
+
if (tid < s) {
|
| 699 |
+
const unsigned int xid1 = tid << 1;
|
| 700 |
+
const unsigned int xid2 = (tid + s) << 1;
|
| 701 |
+
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
|
| 702 |
+
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
|
| 703 |
+
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
|
| 704 |
+
if (tid + (s << 1) < spre)
|
| 705 |
+
{
|
| 706 |
+
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
|
| 707 |
+
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
|
| 708 |
+
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
|
| 709 |
+
}
|
| 710 |
+
}
|
| 711 |
+
__syncthreads();
|
| 712 |
+
}
|
| 713 |
+
|
| 714 |
+
if (tid == 0)
|
| 715 |
+
{
|
| 716 |
+
*grad_sampling_loc = cache_grad_sampling_loc[0];
|
| 717 |
+
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
|
| 718 |
+
*grad_attn_weight = cache_grad_attn_weight[0];
|
| 719 |
+
}
|
| 720 |
+
__syncthreads();
|
| 721 |
+
|
| 722 |
+
data_weight_ptr += 1;
|
| 723 |
+
data_loc_w_ptr += 2;
|
| 724 |
+
grad_attn_weight += grad_weight_stride;
|
| 725 |
+
grad_sampling_loc += grad_loc_stride;
|
| 726 |
+
}
|
| 727 |
+
}
|
| 728 |
+
}
|
| 729 |
+
}
|
| 730 |
+
|
| 731 |
+
template <typename scalar_t>
|
| 732 |
+
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
|
| 733 |
+
const scalar_t *grad_col,
|
| 734 |
+
const scalar_t *data_value,
|
| 735 |
+
const int64_t *data_spatial_shapes,
|
| 736 |
+
const int64_t *data_level_start_index,
|
| 737 |
+
const scalar_t *data_sampling_loc,
|
| 738 |
+
const scalar_t *data_attn_weight,
|
| 739 |
+
const int batch_size,
|
| 740 |
+
const int spatial_size,
|
| 741 |
+
const int num_heads,
|
| 742 |
+
const int channels,
|
| 743 |
+
const int num_levels,
|
| 744 |
+
const int num_query,
|
| 745 |
+
const int num_point,
|
| 746 |
+
scalar_t *grad_value,
|
| 747 |
+
scalar_t *grad_sampling_loc,
|
| 748 |
+
scalar_t *grad_attn_weight)
|
| 749 |
+
{
|
| 750 |
+
CUDA_KERNEL_LOOP(index, n)
|
| 751 |
+
{
|
| 752 |
+
extern __shared__ int _s[];
|
| 753 |
+
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
|
| 754 |
+
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
|
| 755 |
+
unsigned int tid = threadIdx.x;
|
| 756 |
+
int _temp = index;
|
| 757 |
+
const int c_col = _temp % channels;
|
| 758 |
+
_temp /= channels;
|
| 759 |
+
const int sampling_index = _temp;
|
| 760 |
+
const int m_col = _temp % num_heads;
|
| 761 |
+
_temp /= num_heads;
|
| 762 |
+
const int q_col = _temp % num_query;
|
| 763 |
+
_temp /= num_query;
|
| 764 |
+
const int b_col = _temp;
|
| 765 |
+
|
| 766 |
+
const scalar_t top_grad = grad_col[index];
|
| 767 |
+
|
| 768 |
+
int data_weight_ptr = sampling_index * num_levels * num_point;
|
| 769 |
+
int data_loc_w_ptr = data_weight_ptr << 1;
|
| 770 |
+
const int grad_sampling_ptr = data_weight_ptr;
|
| 771 |
+
grad_sampling_loc += grad_sampling_ptr << 1;
|
| 772 |
+
grad_attn_weight += grad_sampling_ptr;
|
| 773 |
+
const int grad_weight_stride = 1;
|
| 774 |
+
const int grad_loc_stride = 2;
|
| 775 |
+
const int qid_stride = num_heads * channels;
|
| 776 |
+
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
| 777 |
+
|
| 778 |
+
for (int l_col=0; l_col < num_levels; ++l_col)
|
| 779 |
+
{
|
| 780 |
+
const int level_start_id = data_level_start_index[l_col];
|
| 781 |
+
const int spatial_h_ptr = l_col << 1;
|
| 782 |
+
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
| 783 |
+
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
| 784 |
+
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
|
| 785 |
+
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
|
| 786 |
+
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
|
| 787 |
+
|
| 788 |
+
for (int p_col=0; p_col < num_point; ++p_col)
|
| 789 |
+
{
|
| 790 |
+
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
| 791 |
+
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
| 792 |
+
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
| 793 |
+
|
| 794 |
+
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
| 795 |
+
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
| 796 |
+
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
|
| 797 |
+
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
|
| 798 |
+
*(cache_grad_attn_weight+threadIdx.x)=0;
|
| 799 |
+
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
|
| 800 |
+
{
|
| 801 |
+
ms_deform_attn_col2im_bilinear(
|
| 802 |
+
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
|
| 803 |
+
top_grad, weight, grad_value_ptr,
|
| 804 |
+
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
|
| 805 |
+
}
|
| 806 |
+
|
| 807 |
+
__syncthreads();
|
| 808 |
+
|
| 809 |
+
for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
|
| 810 |
+
{
|
| 811 |
+
if (tid < s) {
|
| 812 |
+
const unsigned int xid1 = tid << 1;
|
| 813 |
+
const unsigned int xid2 = (tid + s) << 1;
|
| 814 |
+
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
|
| 815 |
+
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
|
| 816 |
+
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
|
| 817 |
+
if (tid + (s << 1) < spre)
|
| 818 |
+
{
|
| 819 |
+
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
|
| 820 |
+
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
|
| 821 |
+
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
|
| 822 |
+
}
|
| 823 |
+
}
|
| 824 |
+
__syncthreads();
|
| 825 |
+
}
|
| 826 |
+
|
| 827 |
+
if (tid == 0)
|
| 828 |
+
{
|
| 829 |
+
atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
|
| 830 |
+
atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
|
| 831 |
+
atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
|
| 832 |
+
}
|
| 833 |
+
__syncthreads();
|
| 834 |
+
|
| 835 |
+
data_weight_ptr += 1;
|
| 836 |
+
data_loc_w_ptr += 2;
|
| 837 |
+
grad_attn_weight += grad_weight_stride;
|
| 838 |
+
grad_sampling_loc += grad_loc_stride;
|
| 839 |
+
}
|
| 840 |
+
}
|
| 841 |
+
}
|
| 842 |
+
}
|
| 843 |
+
|
| 844 |
+
|
| 845 |
+
template <typename scalar_t>
|
| 846 |
+
__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
|
| 847 |
+
const scalar_t *grad_col,
|
| 848 |
+
const scalar_t *data_value,
|
| 849 |
+
const int64_t *data_spatial_shapes,
|
| 850 |
+
const int64_t *data_level_start_index,
|
| 851 |
+
const scalar_t *data_sampling_loc,
|
| 852 |
+
const scalar_t *data_attn_weight,
|
| 853 |
+
const int batch_size,
|
| 854 |
+
const int spatial_size,
|
| 855 |
+
const int num_heads,
|
| 856 |
+
const int channels,
|
| 857 |
+
const int num_levels,
|
| 858 |
+
const int num_query,
|
| 859 |
+
const int num_point,
|
| 860 |
+
scalar_t *grad_value,
|
| 861 |
+
scalar_t *grad_sampling_loc,
|
| 862 |
+
scalar_t *grad_attn_weight)
|
| 863 |
+
{
|
| 864 |
+
CUDA_KERNEL_LOOP(index, n)
|
| 865 |
+
{
|
| 866 |
+
int _temp = index;
|
| 867 |
+
const int c_col = _temp % channels;
|
| 868 |
+
_temp /= channels;
|
| 869 |
+
const int sampling_index = _temp;
|
| 870 |
+
const int m_col = _temp % num_heads;
|
| 871 |
+
_temp /= num_heads;
|
| 872 |
+
const int q_col = _temp % num_query;
|
| 873 |
+
_temp /= num_query;
|
| 874 |
+
const int b_col = _temp;
|
| 875 |
+
|
| 876 |
+
const scalar_t top_grad = grad_col[index];
|
| 877 |
+
|
| 878 |
+
int data_weight_ptr = sampling_index * num_levels * num_point;
|
| 879 |
+
int data_loc_w_ptr = data_weight_ptr << 1;
|
| 880 |
+
const int grad_sampling_ptr = data_weight_ptr;
|
| 881 |
+
grad_sampling_loc += grad_sampling_ptr << 1;
|
| 882 |
+
grad_attn_weight += grad_sampling_ptr;
|
| 883 |
+
const int grad_weight_stride = 1;
|
| 884 |
+
const int grad_loc_stride = 2;
|
| 885 |
+
const int qid_stride = num_heads * channels;
|
| 886 |
+
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
| 887 |
+
|
| 888 |
+
for (int l_col=0; l_col < num_levels; ++l_col)
|
| 889 |
+
{
|
| 890 |
+
const int level_start_id = data_level_start_index[l_col];
|
| 891 |
+
const int spatial_h_ptr = l_col << 1;
|
| 892 |
+
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
| 893 |
+
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
| 894 |
+
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
|
| 895 |
+
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
|
| 896 |
+
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
|
| 897 |
+
|
| 898 |
+
for (int p_col=0; p_col < num_point; ++p_col)
|
| 899 |
+
{
|
| 900 |
+
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
| 901 |
+
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
| 902 |
+
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
| 903 |
+
|
| 904 |
+
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
| 905 |
+
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
| 906 |
+
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
|
| 907 |
+
{
|
| 908 |
+
ms_deform_attn_col2im_bilinear_gm(
|
| 909 |
+
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
|
| 910 |
+
top_grad, weight, grad_value_ptr,
|
| 911 |
+
grad_sampling_loc, grad_attn_weight);
|
| 912 |
+
}
|
| 913 |
+
data_weight_ptr += 1;
|
| 914 |
+
data_loc_w_ptr += 2;
|
| 915 |
+
grad_attn_weight += grad_weight_stride;
|
| 916 |
+
grad_sampling_loc += grad_loc_stride;
|
| 917 |
+
}
|
| 918 |
+
}
|
| 919 |
+
}
|
| 920 |
+
}
|
| 921 |
+
|
| 922 |
+
|
| 923 |
+
template <typename scalar_t>
|
| 924 |
+
void ms_deformable_im2col_cuda(cudaStream_t stream,
|
| 925 |
+
const scalar_t* data_value,
|
| 926 |
+
const int64_t* data_spatial_shapes,
|
| 927 |
+
const int64_t* data_level_start_index,
|
| 928 |
+
const scalar_t* data_sampling_loc,
|
| 929 |
+
const scalar_t* data_attn_weight,
|
| 930 |
+
const int batch_size,
|
| 931 |
+
const int spatial_size,
|
| 932 |
+
const int num_heads,
|
| 933 |
+
const int channels,
|
| 934 |
+
const int num_levels,
|
| 935 |
+
const int num_query,
|
| 936 |
+
const int num_point,
|
| 937 |
+
scalar_t* data_col)
|
| 938 |
+
{
|
| 939 |
+
const int num_kernels = batch_size * num_query * num_heads * channels;
|
| 940 |
+
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
|
| 941 |
+
const int num_threads = CUDA_NUM_THREADS;
|
| 942 |
+
ms_deformable_im2col_gpu_kernel<scalar_t>
|
| 943 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 944 |
+
0, stream>>>(
|
| 945 |
+
num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
|
| 946 |
+
batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
|
| 947 |
+
|
| 948 |
+
cudaError_t err = cudaGetLastError();
|
| 949 |
+
if (err != cudaSuccess)
|
| 950 |
+
{
|
| 951 |
+
printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
|
| 952 |
+
}
|
| 953 |
+
|
| 954 |
+
}
|
| 955 |
+
|
| 956 |
+
template <typename scalar_t>
|
| 957 |
+
void ms_deformable_col2im_cuda(cudaStream_t stream,
|
| 958 |
+
const scalar_t* grad_col,
|
| 959 |
+
const scalar_t* data_value,
|
| 960 |
+
const int64_t * data_spatial_shapes,
|
| 961 |
+
const int64_t * data_level_start_index,
|
| 962 |
+
const scalar_t * data_sampling_loc,
|
| 963 |
+
const scalar_t * data_attn_weight,
|
| 964 |
+
const int batch_size,
|
| 965 |
+
const int spatial_size,
|
| 966 |
+
const int num_heads,
|
| 967 |
+
const int channels,
|
| 968 |
+
const int num_levels,
|
| 969 |
+
const int num_query,
|
| 970 |
+
const int num_point,
|
| 971 |
+
scalar_t* grad_value,
|
| 972 |
+
scalar_t* grad_sampling_loc,
|
| 973 |
+
scalar_t* grad_attn_weight)
|
| 974 |
+
{
|
| 975 |
+
const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
|
| 976 |
+
const int num_kernels = batch_size * num_query * num_heads * channels;
|
| 977 |
+
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
|
| 978 |
+
if (channels > 1024)
|
| 979 |
+
{
|
| 980 |
+
if ((channels & 1023) == 0)
|
| 981 |
+
{
|
| 982 |
+
ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
|
| 983 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 984 |
+
num_threads*3*sizeof(scalar_t), stream>>>(
|
| 985 |
+
num_kernels,
|
| 986 |
+
grad_col,
|
| 987 |
+
data_value,
|
| 988 |
+
data_spatial_shapes,
|
| 989 |
+
data_level_start_index,
|
| 990 |
+
data_sampling_loc,
|
| 991 |
+
data_attn_weight,
|
| 992 |
+
batch_size,
|
| 993 |
+
spatial_size,
|
| 994 |
+
num_heads,
|
| 995 |
+
channels,
|
| 996 |
+
num_levels,
|
| 997 |
+
num_query,
|
| 998 |
+
num_point,
|
| 999 |
+
grad_value,
|
| 1000 |
+
grad_sampling_loc,
|
| 1001 |
+
grad_attn_weight);
|
| 1002 |
+
}
|
| 1003 |
+
else
|
| 1004 |
+
{
|
| 1005 |
+
ms_deformable_col2im_gpu_kernel_gm<scalar_t>
|
| 1006 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1007 |
+
0, stream>>>(
|
| 1008 |
+
num_kernels,
|
| 1009 |
+
grad_col,
|
| 1010 |
+
data_value,
|
| 1011 |
+
data_spatial_shapes,
|
| 1012 |
+
data_level_start_index,
|
| 1013 |
+
data_sampling_loc,
|
| 1014 |
+
data_attn_weight,
|
| 1015 |
+
batch_size,
|
| 1016 |
+
spatial_size,
|
| 1017 |
+
num_heads,
|
| 1018 |
+
channels,
|
| 1019 |
+
num_levels,
|
| 1020 |
+
num_query,
|
| 1021 |
+
num_point,
|
| 1022 |
+
grad_value,
|
| 1023 |
+
grad_sampling_loc,
|
| 1024 |
+
grad_attn_weight);
|
| 1025 |
+
}
|
| 1026 |
+
}
|
| 1027 |
+
else{
|
| 1028 |
+
switch(channels)
|
| 1029 |
+
{
|
| 1030 |
+
case 1:
|
| 1031 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
|
| 1032 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1033 |
+
0, stream>>>(
|
| 1034 |
+
num_kernels,
|
| 1035 |
+
grad_col,
|
| 1036 |
+
data_value,
|
| 1037 |
+
data_spatial_shapes,
|
| 1038 |
+
data_level_start_index,
|
| 1039 |
+
data_sampling_loc,
|
| 1040 |
+
data_attn_weight,
|
| 1041 |
+
batch_size,
|
| 1042 |
+
spatial_size,
|
| 1043 |
+
num_heads,
|
| 1044 |
+
channels,
|
| 1045 |
+
num_levels,
|
| 1046 |
+
num_query,
|
| 1047 |
+
num_point,
|
| 1048 |
+
grad_value,
|
| 1049 |
+
grad_sampling_loc,
|
| 1050 |
+
grad_attn_weight);
|
| 1051 |
+
break;
|
| 1052 |
+
case 2:
|
| 1053 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
|
| 1054 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1055 |
+
0, stream>>>(
|
| 1056 |
+
num_kernels,
|
| 1057 |
+
grad_col,
|
| 1058 |
+
data_value,
|
| 1059 |
+
data_spatial_shapes,
|
| 1060 |
+
data_level_start_index,
|
| 1061 |
+
data_sampling_loc,
|
| 1062 |
+
data_attn_weight,
|
| 1063 |
+
batch_size,
|
| 1064 |
+
spatial_size,
|
| 1065 |
+
num_heads,
|
| 1066 |
+
channels,
|
| 1067 |
+
num_levels,
|
| 1068 |
+
num_query,
|
| 1069 |
+
num_point,
|
| 1070 |
+
grad_value,
|
| 1071 |
+
grad_sampling_loc,
|
| 1072 |
+
grad_attn_weight);
|
| 1073 |
+
break;
|
| 1074 |
+
case 4:
|
| 1075 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
|
| 1076 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1077 |
+
0, stream>>>(
|
| 1078 |
+
num_kernels,
|
| 1079 |
+
grad_col,
|
| 1080 |
+
data_value,
|
| 1081 |
+
data_spatial_shapes,
|
| 1082 |
+
data_level_start_index,
|
| 1083 |
+
data_sampling_loc,
|
| 1084 |
+
data_attn_weight,
|
| 1085 |
+
batch_size,
|
| 1086 |
+
spatial_size,
|
| 1087 |
+
num_heads,
|
| 1088 |
+
channels,
|
| 1089 |
+
num_levels,
|
| 1090 |
+
num_query,
|
| 1091 |
+
num_point,
|
| 1092 |
+
grad_value,
|
| 1093 |
+
grad_sampling_loc,
|
| 1094 |
+
grad_attn_weight);
|
| 1095 |
+
break;
|
| 1096 |
+
case 8:
|
| 1097 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
|
| 1098 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1099 |
+
0, stream>>>(
|
| 1100 |
+
num_kernels,
|
| 1101 |
+
grad_col,
|
| 1102 |
+
data_value,
|
| 1103 |
+
data_spatial_shapes,
|
| 1104 |
+
data_level_start_index,
|
| 1105 |
+
data_sampling_loc,
|
| 1106 |
+
data_attn_weight,
|
| 1107 |
+
batch_size,
|
| 1108 |
+
spatial_size,
|
| 1109 |
+
num_heads,
|
| 1110 |
+
channels,
|
| 1111 |
+
num_levels,
|
| 1112 |
+
num_query,
|
| 1113 |
+
num_point,
|
| 1114 |
+
grad_value,
|
| 1115 |
+
grad_sampling_loc,
|
| 1116 |
+
grad_attn_weight);
|
| 1117 |
+
break;
|
| 1118 |
+
case 16:
|
| 1119 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
|
| 1120 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1121 |
+
0, stream>>>(
|
| 1122 |
+
num_kernels,
|
| 1123 |
+
grad_col,
|
| 1124 |
+
data_value,
|
| 1125 |
+
data_spatial_shapes,
|
| 1126 |
+
data_level_start_index,
|
| 1127 |
+
data_sampling_loc,
|
| 1128 |
+
data_attn_weight,
|
| 1129 |
+
batch_size,
|
| 1130 |
+
spatial_size,
|
| 1131 |
+
num_heads,
|
| 1132 |
+
channels,
|
| 1133 |
+
num_levels,
|
| 1134 |
+
num_query,
|
| 1135 |
+
num_point,
|
| 1136 |
+
grad_value,
|
| 1137 |
+
grad_sampling_loc,
|
| 1138 |
+
grad_attn_weight);
|
| 1139 |
+
break;
|
| 1140 |
+
case 32:
|
| 1141 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
|
| 1142 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1143 |
+
0, stream>>>(
|
| 1144 |
+
num_kernels,
|
| 1145 |
+
grad_col,
|
| 1146 |
+
data_value,
|
| 1147 |
+
data_spatial_shapes,
|
| 1148 |
+
data_level_start_index,
|
| 1149 |
+
data_sampling_loc,
|
| 1150 |
+
data_attn_weight,
|
| 1151 |
+
batch_size,
|
| 1152 |
+
spatial_size,
|
| 1153 |
+
num_heads,
|
| 1154 |
+
channels,
|
| 1155 |
+
num_levels,
|
| 1156 |
+
num_query,
|
| 1157 |
+
num_point,
|
| 1158 |
+
grad_value,
|
| 1159 |
+
grad_sampling_loc,
|
| 1160 |
+
grad_attn_weight);
|
| 1161 |
+
break;
|
| 1162 |
+
case 64:
|
| 1163 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
|
| 1164 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1165 |
+
0, stream>>>(
|
| 1166 |
+
num_kernels,
|
| 1167 |
+
grad_col,
|
| 1168 |
+
data_value,
|
| 1169 |
+
data_spatial_shapes,
|
| 1170 |
+
data_level_start_index,
|
| 1171 |
+
data_sampling_loc,
|
| 1172 |
+
data_attn_weight,
|
| 1173 |
+
batch_size,
|
| 1174 |
+
spatial_size,
|
| 1175 |
+
num_heads,
|
| 1176 |
+
channels,
|
| 1177 |
+
num_levels,
|
| 1178 |
+
num_query,
|
| 1179 |
+
num_point,
|
| 1180 |
+
grad_value,
|
| 1181 |
+
grad_sampling_loc,
|
| 1182 |
+
grad_attn_weight);
|
| 1183 |
+
break;
|
| 1184 |
+
case 128:
|
| 1185 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
|
| 1186 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1187 |
+
0, stream>>>(
|
| 1188 |
+
num_kernels,
|
| 1189 |
+
grad_col,
|
| 1190 |
+
data_value,
|
| 1191 |
+
data_spatial_shapes,
|
| 1192 |
+
data_level_start_index,
|
| 1193 |
+
data_sampling_loc,
|
| 1194 |
+
data_attn_weight,
|
| 1195 |
+
batch_size,
|
| 1196 |
+
spatial_size,
|
| 1197 |
+
num_heads,
|
| 1198 |
+
channels,
|
| 1199 |
+
num_levels,
|
| 1200 |
+
num_query,
|
| 1201 |
+
num_point,
|
| 1202 |
+
grad_value,
|
| 1203 |
+
grad_sampling_loc,
|
| 1204 |
+
grad_attn_weight);
|
| 1205 |
+
break;
|
| 1206 |
+
case 256:
|
| 1207 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
|
| 1208 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1209 |
+
0, stream>>>(
|
| 1210 |
+
num_kernels,
|
| 1211 |
+
grad_col,
|
| 1212 |
+
data_value,
|
| 1213 |
+
data_spatial_shapes,
|
| 1214 |
+
data_level_start_index,
|
| 1215 |
+
data_sampling_loc,
|
| 1216 |
+
data_attn_weight,
|
| 1217 |
+
batch_size,
|
| 1218 |
+
spatial_size,
|
| 1219 |
+
num_heads,
|
| 1220 |
+
channels,
|
| 1221 |
+
num_levels,
|
| 1222 |
+
num_query,
|
| 1223 |
+
num_point,
|
| 1224 |
+
grad_value,
|
| 1225 |
+
grad_sampling_loc,
|
| 1226 |
+
grad_attn_weight);
|
| 1227 |
+
break;
|
| 1228 |
+
case 512:
|
| 1229 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
|
| 1230 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1231 |
+
0, stream>>>(
|
| 1232 |
+
num_kernels,
|
| 1233 |
+
grad_col,
|
| 1234 |
+
data_value,
|
| 1235 |
+
data_spatial_shapes,
|
| 1236 |
+
data_level_start_index,
|
| 1237 |
+
data_sampling_loc,
|
| 1238 |
+
data_attn_weight,
|
| 1239 |
+
batch_size,
|
| 1240 |
+
spatial_size,
|
| 1241 |
+
num_heads,
|
| 1242 |
+
channels,
|
| 1243 |
+
num_levels,
|
| 1244 |
+
num_query,
|
| 1245 |
+
num_point,
|
| 1246 |
+
grad_value,
|
| 1247 |
+
grad_sampling_loc,
|
| 1248 |
+
grad_attn_weight);
|
| 1249 |
+
break;
|
| 1250 |
+
case 1024:
|
| 1251 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
|
| 1252 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1253 |
+
0, stream>>>(
|
| 1254 |
+
num_kernels,
|
| 1255 |
+
grad_col,
|
| 1256 |
+
data_value,
|
| 1257 |
+
data_spatial_shapes,
|
| 1258 |
+
data_level_start_index,
|
| 1259 |
+
data_sampling_loc,
|
| 1260 |
+
data_attn_weight,
|
| 1261 |
+
batch_size,
|
| 1262 |
+
spatial_size,
|
| 1263 |
+
num_heads,
|
| 1264 |
+
channels,
|
| 1265 |
+
num_levels,
|
| 1266 |
+
num_query,
|
| 1267 |
+
num_point,
|
| 1268 |
+
grad_value,
|
| 1269 |
+
grad_sampling_loc,
|
| 1270 |
+
grad_attn_weight);
|
| 1271 |
+
break;
|
| 1272 |
+
default:
|
| 1273 |
+
if (channels < 64)
|
| 1274 |
+
{
|
| 1275 |
+
ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
|
| 1276 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1277 |
+
num_threads*3*sizeof(scalar_t), stream>>>(
|
| 1278 |
+
num_kernels,
|
| 1279 |
+
grad_col,
|
| 1280 |
+
data_value,
|
| 1281 |
+
data_spatial_shapes,
|
| 1282 |
+
data_level_start_index,
|
| 1283 |
+
data_sampling_loc,
|
| 1284 |
+
data_attn_weight,
|
| 1285 |
+
batch_size,
|
| 1286 |
+
spatial_size,
|
| 1287 |
+
num_heads,
|
| 1288 |
+
channels,
|
| 1289 |
+
num_levels,
|
| 1290 |
+
num_query,
|
| 1291 |
+
num_point,
|
| 1292 |
+
grad_value,
|
| 1293 |
+
grad_sampling_loc,
|
| 1294 |
+
grad_attn_weight);
|
| 1295 |
+
}
|
| 1296 |
+
else
|
| 1297 |
+
{
|
| 1298 |
+
ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
|
| 1299 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
| 1300 |
+
num_threads*3*sizeof(scalar_t), stream>>>(
|
| 1301 |
+
num_kernels,
|
| 1302 |
+
grad_col,
|
| 1303 |
+
data_value,
|
| 1304 |
+
data_spatial_shapes,
|
| 1305 |
+
data_level_start_index,
|
| 1306 |
+
data_sampling_loc,
|
| 1307 |
+
data_attn_weight,
|
| 1308 |
+
batch_size,
|
| 1309 |
+
spatial_size,
|
| 1310 |
+
num_heads,
|
| 1311 |
+
channels,
|
| 1312 |
+
num_levels,
|
| 1313 |
+
num_query,
|
| 1314 |
+
num_point,
|
| 1315 |
+
grad_value,
|
| 1316 |
+
grad_sampling_loc,
|
| 1317 |
+
grad_attn_weight);
|
| 1318 |
+
}
|
| 1319 |
+
}
|
| 1320 |
+
}
|
| 1321 |
+
cudaError_t err = cudaGetLastError();
|
| 1322 |
+
if (err != cudaSuccess)
|
| 1323 |
+
{
|
| 1324 |
+
printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
|
| 1325 |
+
}
|
| 1326 |
+
|
| 1327 |
+
}
|
perception_models/apps/detection/DETA_pe/models/ops/src/ms_deform_attn.h
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*!
|
| 2 |
+
**************************************************************************************************
|
| 3 |
+
* Deformable DETR
|
| 4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 6 |
+
**************************************************************************************************
|
| 7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
| 8 |
+
**************************************************************************************************
|
| 9 |
+
*/
|
| 10 |
+
|
| 11 |
+
#pragma once
|
| 12 |
+
|
| 13 |
+
#include "cpu/ms_deform_attn_cpu.h"
|
| 14 |
+
|
| 15 |
+
#ifdef WITH_CUDA
|
| 16 |
+
#include "cuda/ms_deform_attn_cuda.h"
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
at::Tensor
|
| 21 |
+
ms_deform_attn_forward(
|
| 22 |
+
const at::Tensor &value,
|
| 23 |
+
const at::Tensor &spatial_shapes,
|
| 24 |
+
const at::Tensor &level_start_index,
|
| 25 |
+
const at::Tensor &sampling_loc,
|
| 26 |
+
const at::Tensor &attn_weight,
|
| 27 |
+
const int im2col_step)
|
| 28 |
+
{
|
| 29 |
+
if (value.type().is_cuda())
|
| 30 |
+
{
|
| 31 |
+
#ifdef WITH_CUDA
|
| 32 |
+
return ms_deform_attn_cuda_forward(
|
| 33 |
+
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
|
| 34 |
+
#else
|
| 35 |
+
AT_ERROR("Not compiled with GPU support");
|
| 36 |
+
#endif
|
| 37 |
+
}
|
| 38 |
+
AT_ERROR("Not implemented on the CPU");
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
std::vector<at::Tensor>
|
| 42 |
+
ms_deform_attn_backward(
|
| 43 |
+
const at::Tensor &value,
|
| 44 |
+
const at::Tensor &spatial_shapes,
|
| 45 |
+
const at::Tensor &level_start_index,
|
| 46 |
+
const at::Tensor &sampling_loc,
|
| 47 |
+
const at::Tensor &attn_weight,
|
| 48 |
+
const at::Tensor &grad_output,
|
| 49 |
+
const int im2col_step)
|
| 50 |
+
{
|
| 51 |
+
if (value.type().is_cuda())
|
| 52 |
+
{
|
| 53 |
+
#ifdef WITH_CUDA
|
| 54 |
+
return ms_deform_attn_cuda_backward(
|
| 55 |
+
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
|
| 56 |
+
#else
|
| 57 |
+
AT_ERROR("Not compiled with GPU support");
|
| 58 |
+
#endif
|
| 59 |
+
}
|
| 60 |
+
AT_ERROR("Not implemented on the CPU");
|
| 61 |
+
}
|
| 62 |
+
|
perception_models/apps/detection/DETA_pe/models/ops/src/vision.cpp
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*!
|
| 2 |
+
**************************************************************************************************
|
| 3 |
+
* Deformable DETR
|
| 4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 6 |
+
**************************************************************************************************
|
| 7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
| 8 |
+
**************************************************************************************************
|
| 9 |
+
*/
|
| 10 |
+
|
| 11 |
+
#include "ms_deform_attn.h"
|
| 12 |
+
|
| 13 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 14 |
+
m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
|
| 15 |
+
m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
|
| 16 |
+
}
|
perception_models/apps/detection/DETA_pe/models/ops/test.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------------------------------
|
| 6 |
+
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
| 7 |
+
# ------------------------------------------------------------------------------------------------
|
| 8 |
+
|
| 9 |
+
from __future__ import absolute_import
|
| 10 |
+
from __future__ import print_function
|
| 11 |
+
from __future__ import division
|
| 12 |
+
|
| 13 |
+
import time
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
from torch.autograd import gradcheck
|
| 17 |
+
|
| 18 |
+
from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
N, M, D = 1, 2, 2
|
| 22 |
+
Lq, L, P = 2, 2, 2
|
| 23 |
+
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
|
| 24 |
+
level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
|
| 25 |
+
S = sum([(H*W).item() for H, W in shapes])
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
torch.manual_seed(3)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@torch.no_grad()
|
| 32 |
+
def check_forward_equal_with_pytorch_double():
|
| 33 |
+
value = torch.rand(N, S, M, D).cuda() * 0.01
|
| 34 |
+
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
|
| 35 |
+
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
|
| 36 |
+
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
|
| 37 |
+
im2col_step = 2
|
| 38 |
+
output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
|
| 39 |
+
output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
|
| 40 |
+
fwdok = torch.allclose(output_cuda, output_pytorch)
|
| 41 |
+
max_abs_err = (output_cuda - output_pytorch).abs().max()
|
| 42 |
+
max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
|
| 43 |
+
|
| 44 |
+
print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@torch.no_grad()
|
| 48 |
+
def check_forward_equal_with_pytorch_float():
|
| 49 |
+
value = torch.rand(N, S, M, D).cuda() * 0.01
|
| 50 |
+
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
|
| 51 |
+
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
|
| 52 |
+
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
|
| 53 |
+
im2col_step = 2
|
| 54 |
+
output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
|
| 55 |
+
output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
|
| 56 |
+
fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
|
| 57 |
+
max_abs_err = (output_cuda - output_pytorch).abs().max()
|
| 58 |
+
max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
|
| 59 |
+
|
| 60 |
+
print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
|
| 64 |
+
|
| 65 |
+
value = torch.rand(N, S, M, channels).cuda() * 0.01
|
| 66 |
+
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
|
| 67 |
+
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
|
| 68 |
+
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
|
| 69 |
+
im2col_step = 2
|
| 70 |
+
func = MSDeformAttnFunction.apply
|
| 71 |
+
|
| 72 |
+
value.requires_grad = grad_value
|
| 73 |
+
sampling_locations.requires_grad = grad_sampling_loc
|
| 74 |
+
attention_weights.requires_grad = grad_attn_weight
|
| 75 |
+
|
| 76 |
+
gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
|
| 77 |
+
|
| 78 |
+
print(f'* {gradok} check_gradient_numerical(D={channels})')
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
if __name__ == '__main__':
|
| 82 |
+
check_forward_equal_with_pytorch_double()
|
| 83 |
+
check_forward_equal_with_pytorch_float()
|
| 84 |
+
|
| 85 |
+
for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
|
| 86 |
+
check_gradient_numerical(channels, True, True, True)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
|
perception_models/apps/detection/DETA_pe/models/pev1.py
ADDED
|
@@ -0,0 +1,686 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
from functools import partial
|
| 4 |
+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from einops import rearrange, repeat
|
| 9 |
+
from torch import broadcast_tensors, einsum, nn
|
| 10 |
+
from torch.nn.parameter import Parameter
|
| 11 |
+
from torch.utils.checkpoint import checkpoint
|
| 12 |
+
|
| 13 |
+
from .utils_d2 import (
|
| 14 |
+
add_decomposed_rel_pos,
|
| 15 |
+
PatchEmbed,
|
| 16 |
+
window_partition,
|
| 17 |
+
window_unpartition,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_abs_pos(abs_pos, has_cls_token, hw, tile=False):
|
| 22 |
+
h, w = hw
|
| 23 |
+
if has_cls_token:
|
| 24 |
+
abs_pos = abs_pos[:, 1:]
|
| 25 |
+
xy_num = abs_pos.shape[1]
|
| 26 |
+
size = int(math.sqrt(xy_num))
|
| 27 |
+
assert size * size == xy_num
|
| 28 |
+
|
| 29 |
+
if size != h or size != w:
|
| 30 |
+
if tile == True:
|
| 31 |
+
new_abs_pos = abs_pos.reshape(1, size, size, -1).tile(
|
| 32 |
+
[1, h // size + 1, w // size + 1, 1]
|
| 33 |
+
)[:, :h, :w, :]
|
| 34 |
+
|
| 35 |
+
return new_abs_pos
|
| 36 |
+
else:
|
| 37 |
+
new_abs_pos = F.interpolate(
|
| 38 |
+
abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
|
| 39 |
+
size=(h, w),
|
| 40 |
+
mode="bicubic",
|
| 41 |
+
align_corners=False,
|
| 42 |
+
)
|
| 43 |
+
return new_abs_pos.permute(0, 2, 3, 1)
|
| 44 |
+
else:
|
| 45 |
+
return abs_pos.reshape(1, h, w, -1)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# broadcat, as tortoise-tts was using it
|
| 49 |
+
def broadcat(tensors, dim=-1):
|
| 50 |
+
broadcasted_tensors = broadcast_tensors(*tensors)
|
| 51 |
+
return torch.cat(broadcasted_tensors, dim=dim)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# rotary embedding helper functions
|
| 55 |
+
def rotate_half(x):
|
| 56 |
+
x = rearrange(x, "... (d r) -> ... d r", r=2)
|
| 57 |
+
x1, x2 = x.unbind(dim=-1)
|
| 58 |
+
x = torch.stack((-x2, x1), dim=-1)
|
| 59 |
+
return rearrange(x, "... d r -> ... (d r)")
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class VisionRotaryEmbeddingFast(nn.Module):
|
| 63 |
+
def __init__(
|
| 64 |
+
self,
|
| 65 |
+
dim,
|
| 66 |
+
pt_seq_len=16,
|
| 67 |
+
ft_seq_len=None,
|
| 68 |
+
custom_freqs=None,
|
| 69 |
+
freqs_for="lang",
|
| 70 |
+
theta=10000,
|
| 71 |
+
max_freq=10,
|
| 72 |
+
num_freqs=1,
|
| 73 |
+
):
|
| 74 |
+
super().__init__()
|
| 75 |
+
if custom_freqs:
|
| 76 |
+
freqs = custom_freqs
|
| 77 |
+
elif freqs_for == "lang":
|
| 78 |
+
freqs = 1.0 / (
|
| 79 |
+
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
|
| 80 |
+
)
|
| 81 |
+
elif freqs_for == "pixel":
|
| 82 |
+
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
|
| 83 |
+
elif freqs_for == "constant":
|
| 84 |
+
freqs = torch.ones(num_freqs).float()
|
| 85 |
+
else:
|
| 86 |
+
raise ValueError(f"unknown modality {freqs_for}")
|
| 87 |
+
|
| 88 |
+
if ft_seq_len is None:
|
| 89 |
+
ft_seq_len = pt_seq_len
|
| 90 |
+
t = (
|
| 91 |
+
torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + 1
|
| 92 |
+
) # + 1 is hacking vev0 pt code
|
| 93 |
+
|
| 94 |
+
freqs = torch.einsum("..., f -> ... f", t, freqs)
|
| 95 |
+
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
|
| 96 |
+
# freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
|
| 97 |
+
freqs = broadcat(
|
| 98 |
+
(freqs[None, :, :], freqs[:, None, :]), dim=-1
|
| 99 |
+
) # follow vev0 pt code
|
| 100 |
+
|
| 101 |
+
freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
|
| 102 |
+
freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
|
| 103 |
+
|
| 104 |
+
self.register_buffer("freqs_cos", freqs_cos)
|
| 105 |
+
self.register_buffer("freqs_sin", freqs_sin)
|
| 106 |
+
|
| 107 |
+
print("======== shape of rope freq", self.freqs_cos.shape, "========")
|
| 108 |
+
|
| 109 |
+
def forward(self, tt):
|
| 110 |
+
return tt * self.freqs_cos + rotate_half(tt) * self.freqs_sin
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class LayerNorm(nn.LayerNorm):
|
| 114 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
| 115 |
+
|
| 116 |
+
def forward(self, x: torch.Tensor):
|
| 117 |
+
orig_type = x.dtype
|
| 118 |
+
# ret = super().forward(x.type(torch.float32))
|
| 119 |
+
ret = F.layer_norm(
|
| 120 |
+
x.type(torch.float32),
|
| 121 |
+
self.normalized_shape,
|
| 122 |
+
self.weight.type(torch.float32),
|
| 123 |
+
self.bias.type(torch.float32),
|
| 124 |
+
self.eps,
|
| 125 |
+
)
|
| 126 |
+
return ret.type(orig_type)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class QuickGELU(nn.Module):
|
| 130 |
+
def forward(self, x: torch.Tensor):
|
| 131 |
+
return x * torch.sigmoid(1.702 * x)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def drop_path(
|
| 135 |
+
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
|
| 136 |
+
):
|
| 137 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 138 |
+
|
| 139 |
+
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
| 140 |
+
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
| 141 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
| 142 |
+
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
| 143 |
+
'survival rate' as the argument.
|
| 144 |
+
|
| 145 |
+
"""
|
| 146 |
+
if drop_prob == 0.0 or not training:
|
| 147 |
+
return x
|
| 148 |
+
keep_prob = 1 - drop_prob
|
| 149 |
+
shape = (x.shape[0],) + (1,) * (
|
| 150 |
+
x.ndim - 1
|
| 151 |
+
) # work with diff dim tensors, not just 2D ConvNets
|
| 152 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
| 153 |
+
if keep_prob > 0.0 and scale_by_keep:
|
| 154 |
+
random_tensor.div_(keep_prob)
|
| 155 |
+
return x * random_tensor
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class DropPath(nn.Module):
|
| 159 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
| 160 |
+
|
| 161 |
+
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
|
| 162 |
+
super(DropPath, self).__init__()
|
| 163 |
+
self.drop_prob = drop_prob
|
| 164 |
+
self.scale_by_keep = scale_by_keep
|
| 165 |
+
|
| 166 |
+
def forward(self, x):
|
| 167 |
+
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
|
| 168 |
+
|
| 169 |
+
def extra_repr(self):
|
| 170 |
+
return f"drop_prob={round(self.drop_prob,3):0.3f}"
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class Attention(nn.Module):
|
| 174 |
+
r"""
|
| 175 |
+
Implements attention based on Rope
|
| 176 |
+
"""
|
| 177 |
+
|
| 178 |
+
def __init__(
|
| 179 |
+
self,
|
| 180 |
+
embed_dim: int,
|
| 181 |
+
num_heads: int,
|
| 182 |
+
dropout: float = 0.0,
|
| 183 |
+
bias: bool = True,
|
| 184 |
+
add_bias_kv: bool = False,
|
| 185 |
+
kdim: Optional[bool] = None,
|
| 186 |
+
vdim: Optional[bool] = None,
|
| 187 |
+
rope=None,
|
| 188 |
+
):
|
| 189 |
+
super(Attention, self).__init__()
|
| 190 |
+
self.embed_dim = embed_dim
|
| 191 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
| 192 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
| 193 |
+
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
| 194 |
+
|
| 195 |
+
self.num_heads = num_heads
|
| 196 |
+
self.dropout = dropout
|
| 197 |
+
self.head_dim = embed_dim // num_heads
|
| 198 |
+
assert (
|
| 199 |
+
self.head_dim * num_heads == self.embed_dim
|
| 200 |
+
), "embed_dim must be divisible by num_heads"
|
| 201 |
+
|
| 202 |
+
if self._qkv_same_embed_dim is False:
|
| 203 |
+
self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
|
| 204 |
+
self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
|
| 205 |
+
self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
|
| 206 |
+
else:
|
| 207 |
+
self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
|
| 208 |
+
|
| 209 |
+
if bias:
|
| 210 |
+
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
|
| 211 |
+
else:
|
| 212 |
+
self.register_parameter("in_proj_bias", None)
|
| 213 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
| 214 |
+
|
| 215 |
+
if add_bias_kv:
|
| 216 |
+
self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
|
| 217 |
+
self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
|
| 218 |
+
else:
|
| 219 |
+
self.bias_k = self.bias_v = None
|
| 220 |
+
|
| 221 |
+
self.rope = rope
|
| 222 |
+
|
| 223 |
+
self.scale = self.head_dim ** (-0.5)
|
| 224 |
+
|
| 225 |
+
def forward(self, query, attn_mask: Optional[torch.Tensor] = None):
|
| 226 |
+
batch, seq, embed_dim = query.shape
|
| 227 |
+
|
| 228 |
+
proj = torch._C._nn.linear(query, self.in_proj_weight, self.in_proj_bias)
|
| 229 |
+
# reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
|
| 230 |
+
proj = (
|
| 231 |
+
proj.unflatten(-1, (3, embed_dim))
|
| 232 |
+
.unsqueeze(0)
|
| 233 |
+
.transpose(0, -2)
|
| 234 |
+
.squeeze(-2)
|
| 235 |
+
.contiguous()
|
| 236 |
+
)
|
| 237 |
+
q_, k_, v_ = proj[0], proj[1], proj[2]
|
| 238 |
+
|
| 239 |
+
# Use "q_" so that we don't accidentally quit in pdb :)
|
| 240 |
+
q_ = rearrange(q_, "b s (h d) -> b h s d", h=self.num_heads)
|
| 241 |
+
k_ = rearrange(k_, "b s (h d) -> b h s d", h=self.num_heads)
|
| 242 |
+
v_ = rearrange(v_, "b s (h d) -> b h s d", h=self.num_heads)
|
| 243 |
+
|
| 244 |
+
## rope
|
| 245 |
+
q_ = self.rope(q_).type_as(v_)
|
| 246 |
+
k_ = self.rope(k_).type_as(v_)
|
| 247 |
+
|
| 248 |
+
attn = (q_ * self.scale) @ k_.transpose(-2, -1)
|
| 249 |
+
attn = attn.softmax(dim=-1)
|
| 250 |
+
x_ = attn @ v_
|
| 251 |
+
|
| 252 |
+
x_ = rearrange(x_, "b h s d -> b s (h d)")
|
| 253 |
+
|
| 254 |
+
return torch._C._nn.linear(x_, self.out_proj.weight, self.out_proj.bias)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
class LayerScale(nn.Module):
|
| 258 |
+
def __init__(
|
| 259 |
+
self,
|
| 260 |
+
dim: int,
|
| 261 |
+
init_values: float = 1e-5,
|
| 262 |
+
inplace: bool = False,
|
| 263 |
+
) -> None:
|
| 264 |
+
super().__init__()
|
| 265 |
+
self.inplace = inplace
|
| 266 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
| 267 |
+
|
| 268 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 269 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
class ResidualAttentionBlock(nn.Module):
|
| 273 |
+
def __init__(
|
| 274 |
+
self,
|
| 275 |
+
d_model: int,
|
| 276 |
+
n_head: int,
|
| 277 |
+
mlp_ratio=4.0,
|
| 278 |
+
act_layer=nn.GELU,
|
| 279 |
+
norm_layer=LayerNorm,
|
| 280 |
+
drop_path=0.0,
|
| 281 |
+
use_rel_pos=False,
|
| 282 |
+
rel_pos_zero_init=True,
|
| 283 |
+
window_size=0,
|
| 284 |
+
rope=None,
|
| 285 |
+
input_size=None,
|
| 286 |
+
attn_mask=None,
|
| 287 |
+
init_values=0.0,
|
| 288 |
+
):
|
| 289 |
+
super().__init__()
|
| 290 |
+
|
| 291 |
+
self.attn = Attention(embed_dim=d_model, num_heads=n_head, rope=rope)
|
| 292 |
+
self.ls_1 = (
|
| 293 |
+
LayerScale(d_model, init_values=init_values)
|
| 294 |
+
if init_values > 0.0
|
| 295 |
+
else nn.Identity()
|
| 296 |
+
)
|
| 297 |
+
self.ln_1 = LayerNorm(d_model)
|
| 298 |
+
self.mlp = nn.Sequential(
|
| 299 |
+
OrderedDict(
|
| 300 |
+
[
|
| 301 |
+
("c_fc", nn.Linear(d_model, int(d_model * mlp_ratio))),
|
| 302 |
+
("gelu", act_layer()),
|
| 303 |
+
("c_proj", nn.Linear(int(d_model * mlp_ratio), d_model)),
|
| 304 |
+
]
|
| 305 |
+
)
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 309 |
+
self.ln_2 = LayerNorm(d_model)
|
| 310 |
+
self.attn_mask = attn_mask
|
| 311 |
+
self.ls_2 = (
|
| 312 |
+
LayerScale(d_model, init_values=init_values)
|
| 313 |
+
if init_values > 0.0
|
| 314 |
+
else nn.Identity()
|
| 315 |
+
)
|
| 316 |
+
self.window_size = window_size
|
| 317 |
+
|
| 318 |
+
def attention_nhwc(self, x: torch.Tensor):
|
| 319 |
+
self.attn_mask = (
|
| 320 |
+
self.attn_mask.to(dtype=x.dtype, device=x.device)
|
| 321 |
+
if self.attn_mask is not None
|
| 322 |
+
else None
|
| 323 |
+
)
|
| 324 |
+
B, H, W, _ = x.shape
|
| 325 |
+
x = x.reshape(B, H * W, -1)
|
| 326 |
+
x = self.attn(x, attn_mask=self.attn_mask)
|
| 327 |
+
x = x.reshape(B, H, W, -1)
|
| 328 |
+
return x
|
| 329 |
+
|
| 330 |
+
def forward(self, x: torch.Tensor):
|
| 331 |
+
shortcut = x
|
| 332 |
+
|
| 333 |
+
x = self.ln_1(x)
|
| 334 |
+
# Window partition
|
| 335 |
+
if self.window_size > 0:
|
| 336 |
+
H, W = x.shape[1], x.shape[2]
|
| 337 |
+
x, pad_hw = window_partition(x, self.window_size)
|
| 338 |
+
|
| 339 |
+
x = self.attention_nhwc(x)
|
| 340 |
+
# Reverse window partition
|
| 341 |
+
if self.window_size > 0:
|
| 342 |
+
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
|
| 343 |
+
|
| 344 |
+
x = shortcut + self.drop_path(self.ls_1(x))
|
| 345 |
+
x = x + self.drop_path(self.ls_2(self.mlp(self.ln_2(x))))
|
| 346 |
+
return x
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
class Transformer(nn.Module):
|
| 350 |
+
def __init__(
|
| 351 |
+
self,
|
| 352 |
+
embed_dim: int,
|
| 353 |
+
depth: int,
|
| 354 |
+
num_heads: int,
|
| 355 |
+
mlp_ratio=4.0,
|
| 356 |
+
act_layer=nn.GELU,
|
| 357 |
+
norm_layer=LayerNorm,
|
| 358 |
+
drop_path_rate=0.0,
|
| 359 |
+
use_rel_pos=False,
|
| 360 |
+
rel_pos_zero_init=True,
|
| 361 |
+
window_size=0,
|
| 362 |
+
window_block_indexes=(),
|
| 363 |
+
img_size=1024,
|
| 364 |
+
patch_size=16,
|
| 365 |
+
rope_win=None,
|
| 366 |
+
rope_glb=None,
|
| 367 |
+
use_act_checkpoint=False,
|
| 368 |
+
act_checkpoint_ratio=1.0,
|
| 369 |
+
attn_mask=None,
|
| 370 |
+
init_values=0.0,
|
| 371 |
+
return_layer=[-1],
|
| 372 |
+
):
|
| 373 |
+
super().__init__()
|
| 374 |
+
self.use_act_checkpoint = use_act_checkpoint
|
| 375 |
+
self.act_checkpoint_ratio = act_checkpoint_ratio
|
| 376 |
+
|
| 377 |
+
# stochastic depth decay rule
|
| 378 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
|
| 379 |
+
|
| 380 |
+
self.resblocks = nn.ModuleList()
|
| 381 |
+
for i in range(depth):
|
| 382 |
+
block = ResidualAttentionBlock(
|
| 383 |
+
embed_dim,
|
| 384 |
+
num_heads,
|
| 385 |
+
attn_mask=attn_mask,
|
| 386 |
+
drop_path=dpr[i],
|
| 387 |
+
mlp_ratio=mlp_ratio,
|
| 388 |
+
act_layer=act_layer,
|
| 389 |
+
norm_layer=norm_layer,
|
| 390 |
+
use_rel_pos=use_rel_pos,
|
| 391 |
+
rel_pos_zero_init=rel_pos_zero_init,
|
| 392 |
+
window_size=window_size if i in window_block_indexes else 0,
|
| 393 |
+
rope=rope_win if i in window_block_indexes else rope_glb,
|
| 394 |
+
input_size=(img_size // patch_size, img_size // patch_size),
|
| 395 |
+
init_values=init_values,
|
| 396 |
+
)
|
| 397 |
+
self.resblocks.append(block)
|
| 398 |
+
|
| 399 |
+
self.return_layer = return_layer
|
| 400 |
+
|
| 401 |
+
def forward(self, x: torch.Tensor):
|
| 402 |
+
x_list = []
|
| 403 |
+
for idx, blk in enumerate(self.resblocks):
|
| 404 |
+
if (
|
| 405 |
+
self.use_act_checkpoint
|
| 406 |
+
and (idx / len(self.resblocks)) <= self.act_checkpoint_ratio
|
| 407 |
+
):
|
| 408 |
+
x = checkpoint(blk, x)
|
| 409 |
+
else:
|
| 410 |
+
x = blk(x)
|
| 411 |
+
|
| 412 |
+
if idx in self.return_layer or idx == len(self.resblocks) - 1:
|
| 413 |
+
x_list.append(x)
|
| 414 |
+
|
| 415 |
+
return x, x_list
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
class PEv1_simpleFPN(nn.Module):
|
| 419 |
+
def __init__(
|
| 420 |
+
self,
|
| 421 |
+
img_size=1024,
|
| 422 |
+
patch_size=16,
|
| 423 |
+
in_chans=3,
|
| 424 |
+
embed_dim=768,
|
| 425 |
+
depth=12,
|
| 426 |
+
num_heads=12,
|
| 427 |
+
mlp_ratio=4.0,
|
| 428 |
+
qkv_bias=True,
|
| 429 |
+
drop_path_rate=0.0,
|
| 430 |
+
norm_layer=nn.LayerNorm,
|
| 431 |
+
act_layer=nn.GELU,
|
| 432 |
+
use_abs_pos=True,
|
| 433 |
+
use_rel_pos=False,
|
| 434 |
+
rel_pos_zero_init=True,
|
| 435 |
+
rope=True,
|
| 436 |
+
pt_hw_seq_len=16,
|
| 437 |
+
intp_freq=True,
|
| 438 |
+
window_size=0,
|
| 439 |
+
window_block_indexes=(),
|
| 440 |
+
residual_block_indexes=(),
|
| 441 |
+
use_act_checkpoint=False,
|
| 442 |
+
act_checkpoint_ratio=1.0,
|
| 443 |
+
pretrain_img_size=336,
|
| 444 |
+
pretrain_use_cls_token=True,
|
| 445 |
+
out_feature="last_feat",
|
| 446 |
+
tile_posemb=False,
|
| 447 |
+
init_values=0.0,
|
| 448 |
+
tta_rope=False,
|
| 449 |
+
return_layer=[-1],
|
| 450 |
+
):
|
| 451 |
+
super().__init__()
|
| 452 |
+
self.pretrain_use_cls_token = pretrain_use_cls_token
|
| 453 |
+
|
| 454 |
+
self.conv1 = nn.Conv2d(
|
| 455 |
+
in_channels=in_chans,
|
| 456 |
+
out_channels=embed_dim,
|
| 457 |
+
kernel_size=patch_size,
|
| 458 |
+
stride=patch_size,
|
| 459 |
+
bias=False,
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
if use_abs_pos:
|
| 463 |
+
# Initialize absolute positional embedding with pretrain image size.
|
| 464 |
+
num_patches = (pretrain_img_size // patch_size) * (
|
| 465 |
+
pretrain_img_size // patch_size
|
| 466 |
+
)
|
| 467 |
+
num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
|
| 468 |
+
self.positional_embedding = nn.Parameter(
|
| 469 |
+
torch.zeros(1, num_positions, embed_dim)
|
| 470 |
+
)
|
| 471 |
+
print("positional_embedding:", self.positional_embedding.shape)
|
| 472 |
+
print("positional_embedding:", self.positional_embedding.shape)
|
| 473 |
+
print("positional_embedding:", self.positional_embedding.shape)
|
| 474 |
+
|
| 475 |
+
else:
|
| 476 |
+
self.positional_embedding = None
|
| 477 |
+
|
| 478 |
+
self.tile_posemb = tile_posemb
|
| 479 |
+
|
| 480 |
+
self.ln_pre = LayerNorm(embed_dim)
|
| 481 |
+
|
| 482 |
+
half_head_dim = embed_dim // num_heads // 2
|
| 483 |
+
hw_seq_len = img_size // patch_size
|
| 484 |
+
|
| 485 |
+
self.rope_win = VisionRotaryEmbeddingFast(
|
| 486 |
+
dim=half_head_dim,
|
| 487 |
+
pt_seq_len=pt_hw_seq_len,
|
| 488 |
+
ft_seq_len=window_size if intp_freq else None,
|
| 489 |
+
)
|
| 490 |
+
self.rope_glb = VisionRotaryEmbeddingFast(
|
| 491 |
+
dim=half_head_dim,
|
| 492 |
+
pt_seq_len=pt_hw_seq_len,
|
| 493 |
+
ft_seq_len=hw_seq_len if intp_freq else None,
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
self.transformer = Transformer(
|
| 497 |
+
embed_dim=embed_dim,
|
| 498 |
+
depth=depth,
|
| 499 |
+
num_heads=num_heads,
|
| 500 |
+
mlp_ratio=mlp_ratio,
|
| 501 |
+
act_layer=act_layer,
|
| 502 |
+
norm_layer=norm_layer,
|
| 503 |
+
drop_path_rate=drop_path_rate,
|
| 504 |
+
use_rel_pos=use_rel_pos,
|
| 505 |
+
rel_pos_zero_init=rel_pos_zero_init,
|
| 506 |
+
window_size=window_size,
|
| 507 |
+
window_block_indexes=window_block_indexes,
|
| 508 |
+
rope_win=self.rope_win,
|
| 509 |
+
rope_glb=self.rope_glb,
|
| 510 |
+
img_size=img_size,
|
| 511 |
+
patch_size=patch_size,
|
| 512 |
+
use_act_checkpoint=use_act_checkpoint,
|
| 513 |
+
act_checkpoint_ratio=act_checkpoint_ratio,
|
| 514 |
+
init_values=init_values,
|
| 515 |
+
return_layer=return_layer,
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
self._out_feature_channels = {out_feature: embed_dim}
|
| 519 |
+
self._out_feature_strides = {out_feature: patch_size}
|
| 520 |
+
self._out_features = [out_feature]
|
| 521 |
+
|
| 522 |
+
if self.positional_embedding is not None:
|
| 523 |
+
nn.init.trunc_normal_(self.positional_embedding, std=0.02)
|
| 524 |
+
|
| 525 |
+
self.return_layer = return_layer
|
| 526 |
+
# In our method, we don't use backbone feature with stride 4
|
| 527 |
+
self.fpn1 = nn.Sequential(
|
| 528 |
+
nn.ConvTranspose2d(embed_dim, embed_dim // 2, kernel_size=2, stride=2),
|
| 529 |
+
)
|
| 530 |
+
self.fpn2 = nn.Identity()
|
| 531 |
+
self.fpn3 = nn.MaxPool2d(kernel_size=2, stride=2)
|
| 532 |
+
|
| 533 |
+
self.apply(self._init_weights)
|
| 534 |
+
|
| 535 |
+
strides = [patch_size // 2, patch_size, patch_size * 2]
|
| 536 |
+
self._out_features = ["p{}".format(int(math.log2(s))) for s in strides]
|
| 537 |
+
self._out_feature_strides = {
|
| 538 |
+
"p3": 8,
|
| 539 |
+
"p4": 16,
|
| 540 |
+
"p5": 32,
|
| 541 |
+
}
|
| 542 |
+
self._out_feature_channels = {
|
| 543 |
+
"p3": embed_dim // 2,
|
| 544 |
+
"p4": embed_dim,
|
| 545 |
+
"p5": embed_dim,
|
| 546 |
+
}
|
| 547 |
+
self._size_divisibility = strides[-1]
|
| 548 |
+
self._square_pad = img_size
|
| 549 |
+
|
| 550 |
+
def _init_weights(self, m):
|
| 551 |
+
if isinstance(m, nn.Linear):
|
| 552 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
| 553 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 554 |
+
nn.init.constant_(m.bias, 0)
|
| 555 |
+
elif isinstance(m, nn.LayerNorm):
|
| 556 |
+
nn.init.constant_(m.bias, 0)
|
| 557 |
+
nn.init.constant_(m.weight, 1.0)
|
| 558 |
+
|
| 559 |
+
def forward(self, x):
|
| 560 |
+
x = self.conv1(x)
|
| 561 |
+
x = x.permute(0, 2, 3, 1)
|
| 562 |
+
|
| 563 |
+
if self.positional_embedding is not None:
|
| 564 |
+
x = x + get_abs_pos(
|
| 565 |
+
self.positional_embedding,
|
| 566 |
+
self.pretrain_use_cls_token,
|
| 567 |
+
(x.shape[1], x.shape[2]),
|
| 568 |
+
self.tile_posemb,
|
| 569 |
+
)
|
| 570 |
+
x = self.ln_pre(x)
|
| 571 |
+
|
| 572 |
+
x, x_list = self.transformer(x)
|
| 573 |
+
|
| 574 |
+
xp = x.permute(0, 3, 1, 2) # (b, h, w, c) --> (b, c, h, w)
|
| 575 |
+
|
| 576 |
+
features = []
|
| 577 |
+
ops = [self.fpn1, self.fpn2, self.fpn3]
|
| 578 |
+
for i in range(len(ops)):
|
| 579 |
+
features.append(ops[i](xp))
|
| 580 |
+
rets = {"p{}".format(u + 3): v for (u, v) in enumerate(features)}
|
| 581 |
+
|
| 582 |
+
return rets
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
def get_pev1_and_fpn_backbone(args):
|
| 586 |
+
if args.lsj_img_size_max > 0:
|
| 587 |
+
img_size = args.lsj_img_size_max
|
| 588 |
+
else:
|
| 589 |
+
img_size = args.lsj_img_size
|
| 590 |
+
use_act_checkpoint = args.backbone_use_act_checkpoint
|
| 591 |
+
act_checkpoint_ratio = args.backbone_act_checkpoint_ratio
|
| 592 |
+
init_values = args.backbone_init_values
|
| 593 |
+
tile_posemb = args.backbone_tile_posemb
|
| 594 |
+
tta_rope = args.backbone_tta_rope
|
| 595 |
+
multi_layer = args.backbone_multi_layer
|
| 596 |
+
backbone_dp = args.backbone_dp
|
| 597 |
+
|
| 598 |
+
if args.backbone_size == "G":
|
| 599 |
+
embed_dim, depth, num_heads, mlp_ratio, dp = 1536, 50, 16, 8960 / 1536, 0.5
|
| 600 |
+
pretrain_img_size, patch_size, window_size = 224, 16, 14
|
| 601 |
+
window_block_indexes = (
|
| 602 |
+
list(range(0, 12))
|
| 603 |
+
+ list(range(13, 24))
|
| 604 |
+
+ list(range(25, 36))
|
| 605 |
+
+ list(range(37, 49))
|
| 606 |
+
)
|
| 607 |
+
pretrain_use_cls_token = False
|
| 608 |
+
if multi_layer:
|
| 609 |
+
return_layer = [12, 24, 36, 49]
|
| 610 |
+
else:
|
| 611 |
+
return_layer = [-1]
|
| 612 |
+
|
| 613 |
+
elif args.backbone_size == "Gwin384":
|
| 614 |
+
embed_dim, depth, num_heads, mlp_ratio, dp = 1536, 50, 16, 8960 / 1536, 0.5
|
| 615 |
+
pretrain_img_size, patch_size, window_size = 384, 16, 24
|
| 616 |
+
window_block_indexes = (
|
| 617 |
+
list(range(0, 12))
|
| 618 |
+
+ list(range(13, 24))
|
| 619 |
+
+ list(range(25, 36))
|
| 620 |
+
+ list(range(37, 49))
|
| 621 |
+
)
|
| 622 |
+
pretrain_use_cls_token = False
|
| 623 |
+
if multi_layer:
|
| 624 |
+
return_layer = [12, 24, 36, 49]
|
| 625 |
+
else:
|
| 626 |
+
return_layer = [-1]
|
| 627 |
+
|
| 628 |
+
elif args.backbone_size == "Gwin512":
|
| 629 |
+
embed_dim, depth, num_heads, mlp_ratio, dp = 1536, 50, 16, 8960 / 1536, 0.5
|
| 630 |
+
pretrain_img_size, patch_size, window_size = 512, 16, 32
|
| 631 |
+
window_block_indexes = (
|
| 632 |
+
list(range(0, 12))
|
| 633 |
+
+ list(range(13, 24))
|
| 634 |
+
+ list(range(25, 36))
|
| 635 |
+
+ list(range(37, 49))
|
| 636 |
+
)
|
| 637 |
+
pretrain_use_cls_token = False
|
| 638 |
+
if multi_layer:
|
| 639 |
+
return_layer = [12, 24, 36, 49]
|
| 640 |
+
else:
|
| 641 |
+
return_layer = [-1]
|
| 642 |
+
else:
|
| 643 |
+
raise ValueError("Unsupported backbone size")
|
| 644 |
+
|
| 645 |
+
if backbone_dp >= 0:
|
| 646 |
+
dp = backbone_dp
|
| 647 |
+
|
| 648 |
+
assert (
|
| 649 |
+
depth == args.backbone_layers
|
| 650 |
+
), f"backbone depth {depth} and layers {args.backbone_layers}(from config) must be the same"
|
| 651 |
+
|
| 652 |
+
model = PEv1_simpleFPN(
|
| 653 |
+
use_act_checkpoint=use_act_checkpoint,
|
| 654 |
+
act_checkpoint_ratio=act_checkpoint_ratio,
|
| 655 |
+
pretrain_img_size=pretrain_img_size,
|
| 656 |
+
pretrain_use_cls_token=pretrain_use_cls_token,
|
| 657 |
+
img_size=img_size,
|
| 658 |
+
patch_size=patch_size,
|
| 659 |
+
embed_dim=embed_dim,
|
| 660 |
+
depth=depth,
|
| 661 |
+
num_heads=num_heads,
|
| 662 |
+
drop_path_rate=dp,
|
| 663 |
+
window_size=window_size,
|
| 664 |
+
pt_hw_seq_len=16, # Maybe a bug ?
|
| 665 |
+
mlp_ratio=mlp_ratio,
|
| 666 |
+
qkv_bias=True,
|
| 667 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
| 668 |
+
window_block_indexes=window_block_indexes,
|
| 669 |
+
residual_block_indexes=[],
|
| 670 |
+
use_rel_pos=True,
|
| 671 |
+
out_feature="last_feat",
|
| 672 |
+
tile_posemb=tile_posemb,
|
| 673 |
+
init_values=init_values,
|
| 674 |
+
tta_rope=tta_rope,
|
| 675 |
+
return_layer=return_layer,
|
| 676 |
+
)
|
| 677 |
+
|
| 678 |
+
pretrained_backbone_path = args.backbone_path
|
| 679 |
+
if pretrained_backbone_path:
|
| 680 |
+
state_dict = torch.load(pretrained_backbone_path, map_location="cpu")
|
| 681 |
+
load_info = model.load_state_dict(state_dict["model"], strict=False)
|
| 682 |
+
print("Missing keys", load_info.missing_keys)
|
| 683 |
+
print("Unexpected keys", load_info.unexpected_keys)
|
| 684 |
+
else:
|
| 685 |
+
print("Skip pretrained backbone loading")
|
| 686 |
+
return model
|
perception_models/apps/detection/DETA_pe/models/position_encoding.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
# Modified from DETR (https://github.com/facebookresearch/detr)
|
| 7 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 8 |
+
# ------------------------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
Various positional encodings for the transformer.
|
| 12 |
+
"""
|
| 13 |
+
import math
|
| 14 |
+
import torch
|
| 15 |
+
from torch import nn
|
| 16 |
+
|
| 17 |
+
from util.misc import NestedTensor
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class PositionEmbeddingSine(nn.Module):
|
| 21 |
+
"""
|
| 22 |
+
This is a more standard version of the position embedding, very similar to the one
|
| 23 |
+
used by the Attention is all you need paper, generalized to work on images.
|
| 24 |
+
"""
|
| 25 |
+
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.num_pos_feats = num_pos_feats
|
| 28 |
+
self.temperature = temperature
|
| 29 |
+
self.normalize = normalize
|
| 30 |
+
if scale is not None and normalize is False:
|
| 31 |
+
raise ValueError("normalize should be True if scale is passed")
|
| 32 |
+
if scale is None:
|
| 33 |
+
scale = 2 * math.pi
|
| 34 |
+
self.scale = scale
|
| 35 |
+
|
| 36 |
+
def forward(self, tensor_list: NestedTensor):
|
| 37 |
+
x = tensor_list.tensors
|
| 38 |
+
mask = tensor_list.mask
|
| 39 |
+
assert mask is not None
|
| 40 |
+
not_mask = ~mask
|
| 41 |
+
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
| 42 |
+
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
| 43 |
+
if self.normalize:
|
| 44 |
+
eps = 1e-6
|
| 45 |
+
y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
|
| 46 |
+
x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
|
| 47 |
+
|
| 48 |
+
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
| 49 |
+
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
| 50 |
+
|
| 51 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
| 52 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
| 53 |
+
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
| 54 |
+
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
| 55 |
+
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
| 56 |
+
return pos
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class PositionEmbeddingLearned(nn.Module):
|
| 60 |
+
"""
|
| 61 |
+
Absolute pos embedding, learned.
|
| 62 |
+
"""
|
| 63 |
+
def __init__(self, num_pos_feats=256):
|
| 64 |
+
super().__init__()
|
| 65 |
+
self.row_embed = nn.Embedding(50, num_pos_feats)
|
| 66 |
+
self.col_embed = nn.Embedding(50, num_pos_feats)
|
| 67 |
+
self.reset_parameters()
|
| 68 |
+
|
| 69 |
+
def reset_parameters(self):
|
| 70 |
+
nn.init.uniform_(self.row_embed.weight)
|
| 71 |
+
nn.init.uniform_(self.col_embed.weight)
|
| 72 |
+
|
| 73 |
+
def forward(self, tensor_list: NestedTensor):
|
| 74 |
+
x = tensor_list.tensors
|
| 75 |
+
h, w = x.shape[-2:]
|
| 76 |
+
i = torch.arange(w, device=x.device)
|
| 77 |
+
j = torch.arange(h, device=x.device)
|
| 78 |
+
x_emb = self.col_embed(i)
|
| 79 |
+
y_emb = self.row_embed(j)
|
| 80 |
+
pos = torch.cat([
|
| 81 |
+
x_emb.unsqueeze(0).repeat(h, 1, 1),
|
| 82 |
+
y_emb.unsqueeze(1).repeat(1, w, 1),
|
| 83 |
+
], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
|
| 84 |
+
return pos
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def build_position_encoding(args):
|
| 88 |
+
N_steps = args.hidden_dim // 2
|
| 89 |
+
if args.position_embedding in ('v2', 'sine'):
|
| 90 |
+
# TODO find a better way of exposing other arguments
|
| 91 |
+
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
|
| 92 |
+
elif args.position_embedding in ('v3', 'learned'):
|
| 93 |
+
position_embedding = PositionEmbeddingLearned(N_steps)
|
| 94 |
+
else:
|
| 95 |
+
raise ValueError(f"not supported {args.position_embedding}")
|
| 96 |
+
|
| 97 |
+
return position_embedding
|
perception_models/apps/detection/DETA_pe/models/segmentation.py
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Deformable DETR
|
| 3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
| 5 |
+
# ------------------------------------------------------------------------
|
| 6 |
+
# Modified from DETR (https://github.com/facebookresearch/detr)
|
| 7 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 8 |
+
# ------------------------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
This file provides the definition of the convolutional heads used to predict masks, as well as the losses
|
| 12 |
+
"""
|
| 13 |
+
import io
|
| 14 |
+
from collections import defaultdict
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
from PIL import Image
|
| 20 |
+
|
| 21 |
+
import util.box_ops as box_ops
|
| 22 |
+
from util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
from panopticapi.utils import id2rgb, rgb2id
|
| 26 |
+
except ImportError:
|
| 27 |
+
pass
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class DETRsegm(nn.Module):
|
| 31 |
+
def __init__(self, detr, freeze_detr=False):
|
| 32 |
+
super().__init__()
|
| 33 |
+
self.detr = detr
|
| 34 |
+
|
| 35 |
+
if freeze_detr:
|
| 36 |
+
for p in self.parameters():
|
| 37 |
+
p.requires_grad_(False)
|
| 38 |
+
|
| 39 |
+
hidden_dim, nheads = detr.transformer.d_model, detr.transformer.nhead
|
| 40 |
+
self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0)
|
| 41 |
+
self.mask_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim)
|
| 42 |
+
|
| 43 |
+
def forward(self, samples: NestedTensor):
|
| 44 |
+
if not isinstance(samples, NestedTensor):
|
| 45 |
+
samples = nested_tensor_from_tensor_list(samples)
|
| 46 |
+
features, pos = self.detr.backbone(samples)
|
| 47 |
+
|
| 48 |
+
bs = features[-1].tensors.shape[0]
|
| 49 |
+
|
| 50 |
+
src, mask = features[-1].decompose()
|
| 51 |
+
src_proj = self.detr.input_proj(src)
|
| 52 |
+
hs, memory = self.detr.transformer(src_proj, mask, self.detr.query_embed.weight, pos[-1])
|
| 53 |
+
|
| 54 |
+
outputs_class = self.detr.class_embed(hs)
|
| 55 |
+
outputs_coord = self.detr.bbox_embed(hs).sigmoid()
|
| 56 |
+
out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
|
| 57 |
+
if self.detr.aux_loss:
|
| 58 |
+
out["aux_outputs"] = [
|
| 59 |
+
{"pred_logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
|
| 60 |
+
]
|
| 61 |
+
|
| 62 |
+
# FIXME h_boxes takes the last one computed, keep this in mind
|
| 63 |
+
bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask)
|
| 64 |
+
|
| 65 |
+
seg_masks = self.mask_head(src_proj, bbox_mask, [features[2].tensors, features[1].tensors, features[0].tensors])
|
| 66 |
+
outputs_seg_masks = seg_masks.view(bs, self.detr.num_queries, seg_masks.shape[-2], seg_masks.shape[-1])
|
| 67 |
+
|
| 68 |
+
out["pred_masks"] = outputs_seg_masks
|
| 69 |
+
return out
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class MaskHeadSmallConv(nn.Module):
|
| 73 |
+
"""
|
| 74 |
+
Simple convolutional head, using group norm.
|
| 75 |
+
Upsampling is done using a FPN approach
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
def __init__(self, dim, fpn_dims, context_dim):
|
| 79 |
+
super().__init__()
|
| 80 |
+
|
| 81 |
+
inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]
|
| 82 |
+
self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1)
|
| 83 |
+
self.gn1 = torch.nn.GroupNorm(8, dim)
|
| 84 |
+
self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1)
|
| 85 |
+
self.gn2 = torch.nn.GroupNorm(8, inter_dims[1])
|
| 86 |
+
self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
|
| 87 |
+
self.gn3 = torch.nn.GroupNorm(8, inter_dims[2])
|
| 88 |
+
self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
|
| 89 |
+
self.gn4 = torch.nn.GroupNorm(8, inter_dims[3])
|
| 90 |
+
self.lay5 = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
|
| 91 |
+
self.gn5 = torch.nn.GroupNorm(8, inter_dims[4])
|
| 92 |
+
self.out_lay = torch.nn.Conv2d(inter_dims[4], 1, 3, padding=1)
|
| 93 |
+
|
| 94 |
+
self.dim = dim
|
| 95 |
+
|
| 96 |
+
self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
|
| 97 |
+
self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
|
| 98 |
+
self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1)
|
| 99 |
+
|
| 100 |
+
for m in self.modules():
|
| 101 |
+
if isinstance(m, nn.Conv2d):
|
| 102 |
+
nn.init.kaiming_uniform_(m.weight, a=1)
|
| 103 |
+
nn.init.constant_(m.bias, 0)
|
| 104 |
+
|
| 105 |
+
def forward(self, x, bbox_mask, fpns):
|
| 106 |
+
def expand(tensor, length):
|
| 107 |
+
return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)
|
| 108 |
+
|
| 109 |
+
x = torch.cat([expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1)
|
| 110 |
+
|
| 111 |
+
x = self.lay1(x)
|
| 112 |
+
x = self.gn1(x)
|
| 113 |
+
x = F.relu(x)
|
| 114 |
+
x = self.lay2(x)
|
| 115 |
+
x = self.gn2(x)
|
| 116 |
+
x = F.relu(x)
|
| 117 |
+
|
| 118 |
+
cur_fpn = self.adapter1(fpns[0])
|
| 119 |
+
if cur_fpn.size(0) != x.size(0):
|
| 120 |
+
cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0))
|
| 121 |
+
x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
|
| 122 |
+
x = self.lay3(x)
|
| 123 |
+
x = self.gn3(x)
|
| 124 |
+
x = F.relu(x)
|
| 125 |
+
|
| 126 |
+
cur_fpn = self.adapter2(fpns[1])
|
| 127 |
+
if cur_fpn.size(0) != x.size(0):
|
| 128 |
+
cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0))
|
| 129 |
+
x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
|
| 130 |
+
x = self.lay4(x)
|
| 131 |
+
x = self.gn4(x)
|
| 132 |
+
x = F.relu(x)
|
| 133 |
+
|
| 134 |
+
cur_fpn = self.adapter3(fpns[2])
|
| 135 |
+
if cur_fpn.size(0) != x.size(0):
|
| 136 |
+
cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0))
|
| 137 |
+
x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
|
| 138 |
+
x = self.lay5(x)
|
| 139 |
+
x = self.gn5(x)
|
| 140 |
+
x = F.relu(x)
|
| 141 |
+
|
| 142 |
+
x = self.out_lay(x)
|
| 143 |
+
return x
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class MHAttentionMap(nn.Module):
|
| 147 |
+
"""This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
|
| 148 |
+
|
| 149 |
+
def __init__(self, query_dim, hidden_dim, num_heads, dropout=0, bias=True):
|
| 150 |
+
super().__init__()
|
| 151 |
+
self.num_heads = num_heads
|
| 152 |
+
self.hidden_dim = hidden_dim
|
| 153 |
+
self.dropout = nn.Dropout(dropout)
|
| 154 |
+
|
| 155 |
+
self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
|
| 156 |
+
self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
|
| 157 |
+
|
| 158 |
+
nn.init.zeros_(self.k_linear.bias)
|
| 159 |
+
nn.init.zeros_(self.q_linear.bias)
|
| 160 |
+
nn.init.xavier_uniform_(self.k_linear.weight)
|
| 161 |
+
nn.init.xavier_uniform_(self.q_linear.weight)
|
| 162 |
+
self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5
|
| 163 |
+
|
| 164 |
+
def forward(self, q, k, mask=None):
|
| 165 |
+
q = self.q_linear(q)
|
| 166 |
+
k = F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
|
| 167 |
+
qh = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
|
| 168 |
+
kh = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
|
| 169 |
+
weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh)
|
| 170 |
+
|
| 171 |
+
if mask is not None:
|
| 172 |
+
weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf"))
|
| 173 |
+
weights = F.softmax(weights.flatten(2), dim=-1).view_as(weights)
|
| 174 |
+
weights = self.dropout(weights)
|
| 175 |
+
return weights
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def dice_loss(inputs, targets, num_boxes):
|
| 179 |
+
"""
|
| 180 |
+
Compute the DICE loss, similar to generalized IOU for masks
|
| 181 |
+
Args:
|
| 182 |
+
inputs: A float tensor of arbitrary shape.
|
| 183 |
+
The predictions for each example.
|
| 184 |
+
targets: A float tensor with the same shape as inputs. Stores the binary
|
| 185 |
+
classification label for each element in inputs
|
| 186 |
+
(0 for the negative class and 1 for the positive class).
|
| 187 |
+
"""
|
| 188 |
+
inputs = inputs.sigmoid()
|
| 189 |
+
inputs = inputs.flatten(1)
|
| 190 |
+
numerator = 2 * (inputs * targets).sum(1)
|
| 191 |
+
denominator = inputs.sum(-1) + targets.sum(-1)
|
| 192 |
+
loss = 1 - (numerator + 1) / (denominator + 1)
|
| 193 |
+
return loss.sum() / num_boxes
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
|
| 197 |
+
"""
|
| 198 |
+
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
|
| 199 |
+
Args:
|
| 200 |
+
inputs: A float tensor of arbitrary shape.
|
| 201 |
+
The predictions for each example.
|
| 202 |
+
targets: A float tensor with the same shape as inputs. Stores the binary
|
| 203 |
+
classification label for each element in inputs
|
| 204 |
+
(0 for the negative class and 1 for the positive class).
|
| 205 |
+
alpha: (optional) Weighting factor in range (0,1) to balance
|
| 206 |
+
positive vs negative examples. Default = -1 (no weighting).
|
| 207 |
+
gamma: Exponent of the modulating factor (1 - p_t) to
|
| 208 |
+
balance easy vs hard examples.
|
| 209 |
+
Returns:
|
| 210 |
+
Loss tensor
|
| 211 |
+
"""
|
| 212 |
+
prob = inputs.sigmoid()
|
| 213 |
+
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
|
| 214 |
+
p_t = prob * targets + (1 - prob) * (1 - targets)
|
| 215 |
+
loss = ce_loss * ((1 - p_t) ** gamma)
|
| 216 |
+
|
| 217 |
+
if alpha >= 0:
|
| 218 |
+
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
|
| 219 |
+
loss = alpha_t * loss
|
| 220 |
+
|
| 221 |
+
return loss.mean(1).sum() / num_boxes
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class PostProcessSegm(nn.Module):
|
| 225 |
+
def __init__(self, threshold=0.5):
|
| 226 |
+
super().__init__()
|
| 227 |
+
self.threshold = threshold
|
| 228 |
+
|
| 229 |
+
@torch.no_grad()
|
| 230 |
+
def forward(self, results, outputs, orig_target_sizes, max_target_sizes):
|
| 231 |
+
assert len(orig_target_sizes) == len(max_target_sizes)
|
| 232 |
+
max_h, max_w = max_target_sizes.max(0)[0].tolist()
|
| 233 |
+
outputs_masks = outputs["pred_masks"].squeeze(2)
|
| 234 |
+
outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False)
|
| 235 |
+
outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu()
|
| 236 |
+
|
| 237 |
+
for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
|
| 238 |
+
img_h, img_w = t[0], t[1]
|
| 239 |
+
results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
|
| 240 |
+
results[i]["masks"] = F.interpolate(
|
| 241 |
+
results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest"
|
| 242 |
+
).byte()
|
| 243 |
+
|
| 244 |
+
return results
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class PostProcessPanoptic(nn.Module):
|
| 248 |
+
"""This class converts the output of the model to the final panoptic result, in the format expected by the
|
| 249 |
+
coco panoptic API """
|
| 250 |
+
|
| 251 |
+
def __init__(self, is_thing_map, threshold=0.85):
|
| 252 |
+
"""
|
| 253 |
+
Parameters:
|
| 254 |
+
is_thing_map: This is a whose keys are the class ids, and the values a boolean indicating whether
|
| 255 |
+
the class is a thing (True) or a stuff (False) class
|
| 256 |
+
threshold: confidence threshold: segments with confidence lower than this will be deleted
|
| 257 |
+
"""
|
| 258 |
+
super().__init__()
|
| 259 |
+
self.threshold = threshold
|
| 260 |
+
self.is_thing_map = is_thing_map
|
| 261 |
+
|
| 262 |
+
def forward(self, outputs, processed_sizes, target_sizes=None):
|
| 263 |
+
""" This function computes the panoptic prediction from the model's predictions.
|
| 264 |
+
Parameters:
|
| 265 |
+
outputs: This is a dict coming directly from the model. See the model doc for the content.
|
| 266 |
+
processed_sizes: This is a list of tuples (or torch tensors) of sizes of the images that were passed to the
|
| 267 |
+
model, ie the size after data augmentation but before batching.
|
| 268 |
+
target_sizes: This is a list of tuples (or torch tensors) corresponding to the requested final size
|
| 269 |
+
of each prediction. If left to None, it will default to the processed_sizes
|
| 270 |
+
"""
|
| 271 |
+
if target_sizes is None:
|
| 272 |
+
target_sizes = processed_sizes
|
| 273 |
+
assert len(processed_sizes) == len(target_sizes)
|
| 274 |
+
out_logits, raw_masks, raw_boxes = outputs["pred_logits"], outputs["pred_masks"], outputs["pred_boxes"]
|
| 275 |
+
assert len(out_logits) == len(raw_masks) == len(target_sizes)
|
| 276 |
+
preds = []
|
| 277 |
+
|
| 278 |
+
def to_tuple(tup):
|
| 279 |
+
if isinstance(tup, tuple):
|
| 280 |
+
return tup
|
| 281 |
+
return tuple(tup.cpu().tolist())
|
| 282 |
+
|
| 283 |
+
for cur_logits, cur_masks, cur_boxes, size, target_size in zip(
|
| 284 |
+
out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
|
| 285 |
+
):
|
| 286 |
+
# we filter empty queries and detection below threshold
|
| 287 |
+
scores, labels = cur_logits.softmax(-1).max(-1)
|
| 288 |
+
keep = labels.ne(outputs["pred_logits"].shape[-1] - 1) & (scores > self.threshold)
|
| 289 |
+
cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
|
| 290 |
+
cur_scores = cur_scores[keep]
|
| 291 |
+
cur_classes = cur_classes[keep]
|
| 292 |
+
cur_masks = cur_masks[keep]
|
| 293 |
+
cur_masks = interpolate(cur_masks[None], to_tuple(size), mode="bilinear").squeeze(0)
|
| 294 |
+
cur_boxes = box_ops.box_cxcywh_to_xyxy(cur_boxes[keep])
|
| 295 |
+
|
| 296 |
+
h, w = cur_masks.shape[-2:]
|
| 297 |
+
assert len(cur_boxes) == len(cur_classes)
|
| 298 |
+
|
| 299 |
+
# It may be that we have several predicted masks for the same stuff class.
|
| 300 |
+
# In the following, we track the list of masks ids for each stuff class (they are merged later on)
|
| 301 |
+
cur_masks = cur_masks.flatten(1)
|
| 302 |
+
stuff_equiv_classes = defaultdict(lambda: [])
|
| 303 |
+
for k, label in enumerate(cur_classes):
|
| 304 |
+
if not self.is_thing_map[label.item()]:
|
| 305 |
+
stuff_equiv_classes[label.item()].append(k)
|
| 306 |
+
|
| 307 |
+
def get_ids_area(masks, scores, dedup=False):
|
| 308 |
+
# This helper function creates the final panoptic segmentation image
|
| 309 |
+
# It also returns the area of the masks that appears on the image
|
| 310 |
+
|
| 311 |
+
m_id = masks.transpose(0, 1).softmax(-1)
|
| 312 |
+
|
| 313 |
+
if m_id.shape[-1] == 0:
|
| 314 |
+
# We didn't detect any mask :(
|
| 315 |
+
m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device)
|
| 316 |
+
else:
|
| 317 |
+
m_id = m_id.argmax(-1).view(h, w)
|
| 318 |
+
|
| 319 |
+
if dedup:
|
| 320 |
+
# Merge the masks corresponding to the same stuff class
|
| 321 |
+
for equiv in stuff_equiv_classes.values():
|
| 322 |
+
if len(equiv) > 1:
|
| 323 |
+
for eq_id in equiv:
|
| 324 |
+
m_id.masked_fill_(m_id.eq(eq_id), equiv[0])
|
| 325 |
+
|
| 326 |
+
final_h, final_w = to_tuple(target_size)
|
| 327 |
+
|
| 328 |
+
seg_img = Image.fromarray(id2rgb(m_id.view(h, w).cpu().numpy()))
|
| 329 |
+
seg_img = seg_img.resize(size=(final_w, final_h), resample=Image.NEAREST)
|
| 330 |
+
|
| 331 |
+
np_seg_img = (
|
| 332 |
+
torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes())).view(final_h, final_w, 3).numpy()
|
| 333 |
+
)
|
| 334 |
+
m_id = torch.from_numpy(rgb2id(np_seg_img))
|
| 335 |
+
|
| 336 |
+
area = []
|
| 337 |
+
for i in range(len(scores)):
|
| 338 |
+
area.append(m_id.eq(i).sum().item())
|
| 339 |
+
return area, seg_img
|
| 340 |
+
|
| 341 |
+
area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
|
| 342 |
+
if cur_classes.numel() > 0:
|
| 343 |
+
# We know filter empty masks as long as we find some
|
| 344 |
+
while True:
|
| 345 |
+
filtered_small = torch.as_tensor(
|
| 346 |
+
[area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device
|
| 347 |
+
)
|
| 348 |
+
if filtered_small.any().item():
|
| 349 |
+
cur_scores = cur_scores[~filtered_small]
|
| 350 |
+
cur_classes = cur_classes[~filtered_small]
|
| 351 |
+
cur_masks = cur_masks[~filtered_small]
|
| 352 |
+
area, seg_img = get_ids_area(cur_masks, cur_scores)
|
| 353 |
+
else:
|
| 354 |
+
break
|
| 355 |
+
|
| 356 |
+
else:
|
| 357 |
+
cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device)
|
| 358 |
+
|
| 359 |
+
segments_info = []
|
| 360 |
+
for i, a in enumerate(area):
|
| 361 |
+
cat = cur_classes[i].item()
|
| 362 |
+
segments_info.append({"id": i, "isthing": self.is_thing_map[cat], "category_id": cat, "area": a})
|
| 363 |
+
del cur_classes
|
| 364 |
+
|
| 365 |
+
with io.BytesIO() as out:
|
| 366 |
+
seg_img.save(out, format="PNG")
|
| 367 |
+
predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
|
| 368 |
+
preds.append(predictions)
|
| 369 |
+
return preds
|