diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..47b6957b885f951b1edf155fcc47476c9cbbc6cc 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,27 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+NotoSans-Regular.ttf filter=lfs diff=lfs merge=lfs -text
+examples/CNAuxBanner.jpg filter=lfs diff=lfs merge=lfs -text
+examples/ExecuteAll.png filter=lfs diff=lfs merge=lfs -text
+examples/ExecuteAll1.jpg filter=lfs diff=lfs merge=lfs -text
+examples/ExecuteAll2.jpg filter=lfs diff=lfs merge=lfs -text
+examples/comfyui-controlnet-aux-logo.png filter=lfs diff=lfs merge=lfs -text
+examples/example_animal_pose.png filter=lfs diff=lfs merge=lfs -text
+examples/example_anime_face_segmentor.png filter=lfs diff=lfs merge=lfs -text
+examples/example_anyline.png filter=lfs diff=lfs merge=lfs -text
+examples/example_densepose.png filter=lfs diff=lfs merge=lfs -text
+examples/example_depth_anything.png filter=lfs diff=lfs merge=lfs -text
+examples/example_depth_anything_v2.png filter=lfs diff=lfs merge=lfs -text
+examples/example_dsine.png filter=lfs diff=lfs merge=lfs -text
+examples/example_marigold.png filter=lfs diff=lfs merge=lfs -text
+examples/example_marigold_flat.jpg filter=lfs diff=lfs merge=lfs -text
+examples/example_mesh_graphormer.png filter=lfs diff=lfs merge=lfs -text
+examples/example_metric3d.png filter=lfs diff=lfs merge=lfs -text
+examples/example_recolor.png filter=lfs diff=lfs merge=lfs -text
+examples/example_save_kps.png filter=lfs diff=lfs merge=lfs -text
+examples/example_teed.png filter=lfs diff=lfs merge=lfs -text
+examples/example_torchscript.png filter=lfs diff=lfs merge=lfs -text
+examples/example_unimatch.png filter=lfs diff=lfs merge=lfs -text
+src/custom_controlnet_aux/mesh_graphormer/hand_landmarker.task filter=lfs diff=lfs merge=lfs -text
+tests/pose.png filter=lfs diff=lfs merge=lfs -text
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
new file mode 100644
index 0000000000000000000000000000000000000000..eb5b7d5f5484abc954a89dde2b97d327de44b43b
--- /dev/null
+++ b/.github/workflows/publish.yml
@@ -0,0 +1,25 @@
+name: Publish to Comfy registry
+on:
+ workflow_dispatch:
+ push:
+ branches:
+ - main
+ paths:
+ - "pyproject.toml"
+
+permissions:
+ issues: write
+
+jobs:
+ publish-node:
+ name: Publish Custom Node to registry
+ runs-on: ubuntu-latest
+ if: ${{ github.repository_owner == 'Fannovel16' }}
+ steps:
+ - name: Check out code
+ uses: actions/checkout@v4
+ - name: Publish Custom Node
+ uses: Comfy-Org/publish-node-action@v1
+ with:
+ ## Add your own personal access token to your Github Repository secrets and reference it here.
+ personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..c280aed1f84c22714c9bb068e94dc1a0b88a0626
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,183 @@
+# Initially taken from Github's Python gitignore file
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# tests and logs
+tests/fixtures/cached_*_text.txt
+logs/
+lightning_logs/
+lang_code_data/
+tests/outputs
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# vscode
+.vs
+.vscode
+
+# Pycharm
+.idea
+
+# TF code
+tensorflow_code
+
+# Models
+proc_data
+
+# examples
+runs
+/runs_old
+/wandb
+/examples/runs
+/examples/**/*.args
+/examples/rag/sweep
+
+# data
+/data
+serialization_dir
+
+# emacs
+*.*~
+debug.env
+
+# vim
+.*.swp
+
+#ctags
+tags
+
+# pre-commit
+.pre-commit*
+
+# .lock
+*.lock
+
+# DS_Store (MacOS)
+.DS_Store
+# RL pipelines may produce mp4 outputs
+*.mp4
+
+# dependencies
+/transformers
+
+# ruff
+.ruff_cache
+
+wandb
+
+ckpts/
+
+test.ipynb
+config.yaml
+test.ipynb
\ No newline at end of file
diff --git a/LICENSE.txt b/LICENSE.txt
new file mode 100644
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/LICENSE.txt
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/NotoSans-Regular.ttf b/NotoSans-Regular.ttf
new file mode 100644
index 0000000000000000000000000000000000000000..4df6c6c6d0a0aa7c81f7bacfe4240bdd968d2d65
--- /dev/null
+++ b/NotoSans-Regular.ttf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6b04c8dd65af6b73eb4279472ed1580b29102d6496a377340e80a40cdb3b22c9
+size 455188
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1c0e2dcc38eac48b290b4e323920e8b1477d4b8b
--- /dev/null
+++ b/README.md
@@ -0,0 +1,252 @@
+# ComfyUI's ControlNet Auxiliary Preprocessors
+Plug-and-play [ComfyUI](https://github.com/comfyanonymous/ComfyUI) node sets for making [ControlNet](https://github.com/lllyasviel/ControlNet/) hint images
+
+"anime style, a protest in the street, cyberpunk city, a woman with pink hair and golden eyes (looking at the viewer) is holding a sign with the text "ComfyUI ControlNet Aux" in bold, neon pink" on Flux.1 Dev
+
+
+
+The code is copy-pasted from the respective folders in https://github.com/lllyasviel/ControlNet/tree/main/annotator and connected to [the 🤗 Hub](https://huggingface.co/lllyasviel/Annotators).
+
+All credit & copyright goes to https://github.com/lllyasviel.
+
+# Updates
+Go to [Update page](./UPDATES.md) to follow updates
+
+# Installation:
+## Using ComfyUI Manager (recommended):
+Install [ComfyUI Manager](https://github.com/ltdrdata/ComfyUI-Manager) and do steps introduced there to install this repo.
+
+## Alternative:
+If you're running on Linux, or non-admin account on windows you'll want to ensure `/ComfyUI/custom_nodes` and `comfyui_controlnet_aux` has write permissions.
+
+There is now a **install.bat** you can run to install to portable if detected. Otherwise it will default to system and assume you followed ConfyUI's manual installation steps.
+
+If you can't run **install.bat** (e.g. you are a Linux user). Open the CMD/Shell and do the following:
+ - Navigate to your `/ComfyUI/custom_nodes/` folder
+ - Run `git clone https://github.com/Fannovel16/comfyui_controlnet_aux/`
+ - Navigate to your `comfyui_controlnet_aux` folder
+ - Portable/venv:
+ - Run `path/to/ComfUI/python_embeded/python.exe -s -m pip install -r requirements.txt`
+ - With system python
+ - Run `pip install -r requirements.txt`
+ - Start ComfyUI
+
+# Nodes
+Please note that this repo only supports preprocessors making hint images (e.g. stickman, canny edge, etc).
+All preprocessors except Inpaint are intergrated into `AIO Aux Preprocessor` node.
+This node allow you to quickly get the preprocessor but a preprocessor's own threshold parameters won't be able to set.
+You need to use its node directly to set thresholds.
+
+# Nodes (sections are categories in Comfy menu)
+## Line Extractors
+| Preprocessor Node | sd-webui-controlnet/other | ControlNet/T2I-Adapter |
+|-----------------------------|---------------------------|-------------------------------------------|
+| Binary Lines | binary | control_scribble |
+| Canny Edge | canny | control_v11p_sd15_canny
control_canny
t2iadapter_canny |
+| HED Soft-Edge Lines | hed | control_v11p_sd15_softedge
control_hed |
+| Standard Lineart | standard_lineart | control_v11p_sd15_lineart |
+| Realistic Lineart | lineart (or `lineart_coarse` if `coarse` is enabled) | control_v11p_sd15_lineart |
+| Anime Lineart | lineart_anime | control_v11p_sd15s2_lineart_anime |
+| Manga Lineart | lineart_anime_denoise | control_v11p_sd15s2_lineart_anime |
+| M-LSD Lines | mlsd | control_v11p_sd15_mlsd
control_mlsd |
+| PiDiNet Soft-Edge Lines | pidinet | control_v11p_sd15_softedge
control_scribble |
+| Scribble Lines | scribble | control_v11p_sd15_scribble
control_scribble |
+| Scribble XDoG Lines | scribble_xdog | control_v11p_sd15_scribble
control_scribble |
+| Fake Scribble Lines | scribble_hed | control_v11p_sd15_scribble
control_scribble |
+| TEED Soft-Edge Lines | teed | [controlnet-sd-xl-1.0-softedge-dexined](https://huggingface.co/SargeZT/controlnet-sd-xl-1.0-softedge-dexined/blob/main/controlnet-sd-xl-1.0-softedge-dexined.safetensors)
control_v11p_sd15_softedge (Theoretically)
+| Scribble PiDiNet Lines | scribble_pidinet | control_v11p_sd15_scribble
control_scribble |
+| AnyLine Lineart | | mistoLine_fp16.safetensors
mistoLine_rank256
control_v11p_sd15s2_lineart_anime
control_v11p_sd15_lineart |
+
+## Normal and Depth Estimators
+| Preprocessor Node | sd-webui-controlnet/other | ControlNet/T2I-Adapter |
+|-----------------------------|---------------------------|-------------------------------------------|
+| MiDaS Depth Map | (normal) depth | control_v11f1p_sd15_depth
control_depth
t2iadapter_depth |
+| LeReS Depth Map | depth_leres | control_v11f1p_sd15_depth
control_depth
t2iadapter_depth |
+| Zoe Depth Map | depth_zoe | control_v11f1p_sd15_depth
control_depth
t2iadapter_depth |
+| MiDaS Normal Map | normal_map | control_normal |
+| BAE Normal Map | normal_bae | control_v11p_sd15_normalbae |
+| MeshGraphormer Hand Refiner ([HandRefinder](https://github.com/wenquanlu/HandRefiner)) | depth_hand_refiner | [control_sd15_inpaint_depth_hand_fp16](https://huggingface.co/hr16/ControlNet-HandRefiner-pruned/blob/main/control_sd15_inpaint_depth_hand_fp16.safetensors) |
+| Depth Anything | depth_anything | [Depth-Anything](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints_controlnet/diffusion_pytorch_model.safetensors) |
+| Zoe Depth Anything
(Basically Zoe but the encoder is replaced with DepthAnything) | depth_anything | [Depth-Anything](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints_controlnet/diffusion_pytorch_model.safetensors) |
+| Normal DSINE | | control_normal/control_v11p_sd15_normalbae |
+| Metric3D Depth | | control_v11f1p_sd15_depth
control_depth
t2iadapter_depth |
+| Metric3D Normal | | control_v11p_sd15_normalbae |
+| Depth Anything V2 | | [Depth-Anything](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints_controlnet/diffusion_pytorch_model.safetensors) |
+
+## Faces and Poses Estimators
+| Preprocessor Node | sd-webui-controlnet/other | ControlNet/T2I-Adapter |
+|-----------------------------|---------------------------|-------------------------------------------|
+| DWPose Estimator | dw_openpose_full | control_v11p_sd15_openpose
control_openpose
t2iadapter_openpose |
+| OpenPose Estimator | openpose (detect_body)
openpose_hand (detect_body + detect_hand)
openpose_faceonly (detect_face)
openpose_full (detect_hand + detect_body + detect_face) | control_v11p_sd15_openpose
control_openpose
t2iadapter_openpose |
+| MediaPipe Face Mesh | mediapipe_face | controlnet_sd21_laion_face_v2 |
+| Animal Estimator | animal_openpose | [control_sd15_animal_openpose_fp16](https://huggingface.co/huchenlei/animal_openpose/blob/main/control_sd15_animal_openpose_fp16.pth) |
+
+## Optical Flow Estimators
+| Preprocessor Node | sd-webui-controlnet/other | ControlNet/T2I-Adapter |
+|-----------------------------|---------------------------|-------------------------------------------|
+| Unimatch Optical Flow | | [DragNUWA](https://github.com/ProjectNUWA/DragNUWA) |
+
+### How to get OpenPose-format JSON?
+#### User-side
+This workflow will save images to ComfyUI's output folder (the same location as output images). If you haven't found `Save Pose Keypoints` node, update this extension
+
+
+#### Dev-side
+An array of [OpenPose-format JSON](https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/doc/02_output.md#json-output-format) corresponsding to each frame in an IMAGE batch can be gotten from DWPose and OpenPose using `app.nodeOutputs` on the UI or `/history` API endpoint. JSON output from AnimalPose uses a kinda similar format to OpenPose JSON:
+```
+[
+ {
+ "version": "ap10k",
+ "animals": [
+ [[x1, y1, 1], [x2, y2, 1],..., [x17, y17, 1]],
+ [[x1, y1, 1], [x2, y2, 1],..., [x17, y17, 1]],
+ ...
+ ],
+ "canvas_height": 512,
+ "canvas_width": 768
+ },
+ ...
+]
+```
+
+For extension developers (e.g. Openpose editor):
+```js
+const poseNodes = app.graph._nodes.filter(node => ["OpenposePreprocessor", "DWPreprocessor", "AnimalPosePreprocessor"].includes(node.type))
+for (const poseNode of poseNodes) {
+ const openposeResults = JSON.parse(app.nodeOutputs[poseNode.id].openpose_json[0])
+ console.log(openposeResults) //An array containing Openpose JSON for each frame
+}
+```
+
+For API users:
+Javascript
+```js
+import fetch from "node-fetch" //Remember to add "type": "module" to "package.json"
+async function main() {
+ const promptId = '792c1905-ecfe-41f4-8114-83e6a4a09a9f' //Too lazy to POST /queue
+ let history = await fetch(`http://127.0.0.1:8188/history/${promptId}`).then(re => re.json())
+ history = history[promptId]
+ const nodeOutputs = Object.values(history.outputs).filter(output => output.openpose_json)
+ for (const nodeOutput of nodeOutputs) {
+ const openposeResults = JSON.parse(nodeOutput.openpose_json[0])
+ console.log(openposeResults) //An array containing Openpose JSON for each frame
+ }
+}
+main()
+```
+
+Python
+```py
+import json, urllib.request
+
+server_address = "127.0.0.1:8188"
+prompt_id = '' #Too lazy to POST /queue
+
+def get_history(prompt_id):
+ with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
+ return json.loads(response.read())
+
+history = get_history(prompt_id)[prompt_id]
+for o in history['outputs']:
+ for node_id in history['outputs']:
+ node_output = history['outputs'][node_id]
+ if 'openpose_json' in node_output:
+ print(json.loads(node_output['openpose_json'][0])) #An list containing Openpose JSON for each frame
+```
+## Semantic Segmentation
+| Preprocessor Node | sd-webui-controlnet/other | ControlNet/T2I-Adapter |
+|-----------------------------|---------------------------|-------------------------------------------|
+| OneFormer ADE20K Segmentor | oneformer_ade20k | control_v11p_sd15_seg |
+| OneFormer COCO Segmentor | oneformer_coco | control_v11p_sd15_seg |
+| UniFormer Segmentor | segmentation |control_sd15_seg
control_v11p_sd15_seg|
+
+## T2IAdapter-only
+| Preprocessor Node | sd-webui-controlnet/other | ControlNet/T2I-Adapter |
+|-----------------------------|---------------------------|-------------------------------------------|
+| Color Pallete | color | t2iadapter_color |
+| Content Shuffle | shuffle | t2iadapter_style |
+
+## Recolor
+| Preprocessor Node | sd-webui-controlnet/other | ControlNet/T2I-Adapter |
+|-----------------------------|---------------------------|-------------------------------------------|
+| Image Luminance | recolor_luminance | [ioclab_sd15_recolor](https://huggingface.co/lllyasviel/sd_control_collection/resolve/main/ioclab_sd15_recolor.safetensors)
[sai_xl_recolor_256lora](https://huggingface.co/lllyasviel/sd_control_collection/resolve/main/sai_xl_recolor_256lora.safetensors)
[bdsqlsz_controlllite_xl_recolor_luminance](https://huggingface.co/bdsqlsz/qinglong_controlnet-lllite/resolve/main/bdsqlsz_controlllite_xl_recolor_luminance.safetensors) |
+| Image Intensity | recolor_intensity | Idk. Maybe same as above? |
+
+# Examples
+> A picture is worth a thousand words
+
+
+
+
+# Testing workflow
+https://github.com/Fannovel16/comfyui_controlnet_aux/blob/main/examples/ExecuteAll.png
+Input image: https://github.com/Fannovel16/comfyui_controlnet_aux/blob/main/examples/comfyui-controlnet-aux-logo.png
+
+# Q&A:
+## Why some nodes doesn't appear after I installed this repo?
+
+This repo has a new mechanism which will skip any custom node can't be imported. If you meet this case, please create a issue on [Issues tab](https://github.com/Fannovel16/comfyui_controlnet_aux/issues) with the log from the command line.
+
+## DWPose/AnimalPose only uses CPU so it's so slow. How can I make it use GPU?
+There are two ways to speed-up DWPose: using TorchScript checkpoints (.torchscript.pt) checkpoints or ONNXRuntime (.onnx). TorchScript way is little bit slower than ONNXRuntime but doesn't require any additional library and still way way faster than CPU.
+
+A torchscript bbox detector is compatiable with an onnx pose estimator and vice versa.
+### TorchScript
+Set `bbox_detector` and `pose_estimator` according to this picture. You can try other bbox detector endings with `.torchscript.pt` to reduce bbox detection time if input images are ideal.
+
+### ONNXRuntime
+If onnxruntime is installed successfully and the checkpoint used endings with `.onnx`, it will replace default cv2 backend to take advantage of GPU. Note that if you are using NVidia card, this method currently can only works on CUDA 11.8 (ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z) unless you compile onnxruntime yourself.
+
+1. Know your onnxruntime build:
+* * NVidia CUDA 11.x or bellow/AMD GPU: `onnxruntime-gpu`
+* * NVidia CUDA 12.x: `onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/`
+* * DirectML: `onnxruntime-directml`
+* * OpenVINO: `onnxruntime-openvino`
+
+Note that if this is your first time using ComfyUI, please test if it can run on your device before doing next steps.
+
+2. Add it into `requirements.txt`
+
+3. Run `install.bat` or pip command mentioned in Installation
+
+
+
+# Assets files of preprocessors
+* anime_face_segment: [bdsqlsz/qinglong_controlnet-lllite/Annotators/UNet.pth](https://huggingface.co/bdsqlsz/qinglong_controlnet-lllite/blob/main/Annotators/UNet.pth), [anime-seg/isnetis.ckpt](https://huggingface.co/skytnt/anime-seg/blob/main/isnetis.ckpt)
+* densepose: [LayerNorm/DensePose-TorchScript-with-hint-image/densepose_r50_fpn_dl.torchscript](https://huggingface.co/LayerNorm/DensePose-TorchScript-with-hint-image/blob/main/densepose_r50_fpn_dl.torchscript)
+* dwpose:
+* * bbox_detector: Either [yzd-v/DWPose/yolox_l.onnx](https://huggingface.co/yzd-v/DWPose/blob/main/yolox_l.onnx), [hr16/yolox-onnx/yolox_l.torchscript.pt](https://huggingface.co/hr16/yolox-onnx/blob/main/yolox_l.torchscript.pt), [hr16/yolo-nas-fp16/yolo_nas_l_fp16.onnx](https://huggingface.co/hr16/yolo-nas-fp16/blob/main/yolo_nas_l_fp16.onnx), [hr16/yolo-nas-fp16/yolo_nas_m_fp16.onnx](https://huggingface.co/hr16/yolo-nas-fp16/blob/main/yolo_nas_m_fp16.onnx), [hr16/yolo-nas-fp16/yolo_nas_s_fp16.onnx](https://huggingface.co/hr16/yolo-nas-fp16/blob/main/yolo_nas_s_fp16.onnx)
+* * pose_estimator: Either [hr16/DWPose-TorchScript-BatchSize5/dw-ll_ucoco_384_bs5.torchscript.pt](https://huggingface.co/hr16/DWPose-TorchScript-BatchSize5/blob/main/dw-ll_ucoco_384_bs5.torchscript.pt), [yzd-v/DWPose/dw-ll_ucoco_384.onnx](https://huggingface.co/yzd-v/DWPose/blob/main/dw-ll_ucoco_384.onnx)
+* animal_pose (ap10k):
+* * bbox_detector: Either [yzd-v/DWPose/yolox_l.onnx](https://huggingface.co/yzd-v/DWPose/blob/main/yolox_l.onnx), [hr16/yolox-onnx/yolox_l.torchscript.pt](https://huggingface.co/hr16/yolox-onnx/blob/main/yolox_l.torchscript.pt), [hr16/yolo-nas-fp16/yolo_nas_l_fp16.onnx](https://huggingface.co/hr16/yolo-nas-fp16/blob/main/yolo_nas_l_fp16.onnx), [hr16/yolo-nas-fp16/yolo_nas_m_fp16.onnx](https://huggingface.co/hr16/yolo-nas-fp16/blob/main/yolo_nas_m_fp16.onnx), [hr16/yolo-nas-fp16/yolo_nas_s_fp16.onnx](https://huggingface.co/hr16/yolo-nas-fp16/blob/main/yolo_nas_s_fp16.onnx)
+* * pose_estimator: Either [hr16/DWPose-TorchScript-BatchSize5/rtmpose-m_ap10k_256_bs5.torchscript.pt](https://huggingface.co/hr16/DWPose-TorchScript-BatchSize5/blob/main/rtmpose-m_ap10k_256_bs5.torchscript.pt), [hr16/UnJIT-DWPose/rtmpose-m_ap10k_256.onnx](https://huggingface.co/hr16/UnJIT-DWPose/blob/main/rtmpose-m_ap10k_256.onnx)
+* hed: [lllyasviel/Annotators/ControlNetHED.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/ControlNetHED.pth)
+* leres: [lllyasviel/Annotators/res101.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/res101.pth), [lllyasviel/Annotators/latest_net_G.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/latest_net_G.pth)
+* lineart: [lllyasviel/Annotators/sk_model.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/sk_model.pth), [lllyasviel/Annotators/sk_model2.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/sk_model2.pth)
+* lineart_anime: [lllyasviel/Annotators/netG.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/netG.pth)
+* manga_line: [lllyasviel/Annotators/erika.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/erika.pth)
+* mesh_graphormer: [hr16/ControlNet-HandRefiner-pruned/graphormer_hand_state_dict.bin](https://huggingface.co/hr16/ControlNet-HandRefiner-pruned/blob/main/graphormer_hand_state_dict.bin), [hr16/ControlNet-HandRefiner-pruned/hrnetv2_w64_imagenet_pretrained.pth](https://huggingface.co/hr16/ControlNet-HandRefiner-pruned/blob/main/hrnetv2_w64_imagenet_pretrained.pth)
+* midas: [lllyasviel/Annotators/dpt_hybrid-midas-501f0c75.pt](https://huggingface.co/lllyasviel/Annotators/blob/main/dpt_hybrid-midas-501f0c75.pt)
+* mlsd: [lllyasviel/Annotators/mlsd_large_512_fp32.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/mlsd_large_512_fp32.pth)
+* normalbae: [lllyasviel/Annotators/scannet.pt](https://huggingface.co/lllyasviel/Annotators/blob/main/scannet.pt)
+* oneformer: [lllyasviel/Annotators/250_16_swin_l_oneformer_ade20k_160k.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/250_16_swin_l_oneformer_ade20k_160k.pth)
+* open_pose: [lllyasviel/Annotators/body_pose_model.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/body_pose_model.pth), [lllyasviel/Annotators/hand_pose_model.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/hand_pose_model.pth), [lllyasviel/Annotators/facenet.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/facenet.pth)
+* pidi: [lllyasviel/Annotators/table5_pidinet.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/table5_pidinet.pth)
+* sam: [dhkim2810/MobileSAM/mobile_sam.pt](https://huggingface.co/dhkim2810/MobileSAM/blob/main/mobile_sam.pt)
+* uniformer: [lllyasviel/Annotators/upernet_global_small.pth](https://huggingface.co/lllyasviel/Annotators/blob/main/upernet_global_small.pth)
+* zoe: [lllyasviel/Annotators/ZoeD_M12_N.pt](https://huggingface.co/lllyasviel/Annotators/blob/main/ZoeD_M12_N.pt)
+* teed: [bdsqlsz/qinglong_controlnet-lllite/7_model.pth](https://huggingface.co/bdsqlsz/qinglong_controlnet-lllite/blob/main/Annotators/7_model.pth)
+* depth_anything: Either [LiheYoung/Depth-Anything/checkpoints/depth_anything_vitl14.pth](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints/depth_anything_vitl14.pth), [LiheYoung/Depth-Anything/checkpoints/depth_anything_vitb14.pth](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints/depth_anything_vitb14.pth) or [LiheYoung/Depth-Anything/checkpoints/depth_anything_vits14.pth](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints/depth_anything_vits14.pth)
+* diffusion_edge: Either [hr16/Diffusion-Edge/diffusion_edge_indoor.pt](https://huggingface.co/hr16/Diffusion-Edge/blob/main/diffusion_edge_indoor.pt), [hr16/Diffusion-Edge/diffusion_edge_urban.pt](https://huggingface.co/hr16/Diffusion-Edge/blob/main/diffusion_edge_urban.pt) or [hr16/Diffusion-Edge/diffusion_edge_natrual.pt](https://huggingface.co/hr16/Diffusion-Edge/blob/main/diffusion_edge_natrual.pt)
+* unimatch: Either [hr16/Unimatch/gmflow-scale2-regrefine6-mixdata.pth](https://huggingface.co/hr16/Unimatch/blob/main/gmflow-scale2-regrefine6-mixdata.pth), [hr16/Unimatch/gmflow-scale2-mixdata.pth](https://huggingface.co/hr16/Unimatch/blob/main/gmflow-scale2-mixdata.pth) or [hr16/Unimatch/gmflow-scale1-mixdata.pth](https://huggingface.co/hr16/Unimatch/blob/main/gmflow-scale1-mixdata.pth)
+* zoe_depth_anything: Either [LiheYoung/Depth-Anything/checkpoints_metric_depth/depth_anything_metric_depth_indoor.pt](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints_metric_depth/depth_anything_metric_depth_indoor.pt) or [LiheYoung/Depth-Anything/checkpoints_metric_depth/depth_anything_metric_depth_outdoor.pt](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints_metric_depth/depth_anything_metric_depth_outdoor.pt)
+# 2000 Stars 😄
+
+
+
+
+
+
+
+
+Thanks for yalls supports. I never thought the graph for stars would be linear lol.
diff --git a/UPDATES.md b/UPDATES.md
new file mode 100644
index 0000000000000000000000000000000000000000..7d5a8f622a72e536d30d86838dc9770612787ebd
--- /dev/null
+++ b/UPDATES.md
@@ -0,0 +1,45 @@
+* `AIO Aux Preprocessor` intergrating all loadable aux preprocessors as dropdown options. Easy to copy, paste and get the preprocessor faster.
+* Added OpenPose-format JSON output from OpenPose Preprocessor and DWPose Preprocessor. Checks [here](#faces-and-poses).
+* Fixed wrong model path when downloading DWPose.
+* Make hint images less blurry.
+* Added `resolution` option, `PixelPerfectResolution` and `HintImageEnchance` nodes (TODO: Documentation).
+* Added `RAFT Optical Flow Embedder` for TemporalNet2 (TODO: Workflow example).
+* Fixed opencv's conflicts between this extension, [ReActor](https://github.com/Gourieff/comfyui-reactor-node) and Roop. Thanks `Gourieff` for [the solution](https://github.com/Fannovel16/comfyui_controlnet_aux/issues/7#issuecomment-1734319075)!
+* RAFT is removed as the code behind it doesn't match what what the original code does
+* Changed `lineart`'s display name from `Normal Lineart` to `Realistic Lineart`. This change won't affect old workflows
+* Added support for `onnxruntime` to speed-up DWPose (see the Q&A)
+* Fixed TypeError: expected size to be one of int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int], but got size with types [, ]: [Issue](https://github.com/Fannovel16/comfyui_controlnet_aux/issues/2), [PR](https://github.com/Fannovel16/comfyui_controlnet_aux/pull/71))
+* Fixed ImageGenResolutionFromImage mishape (https://github.com/Fannovel16/comfyui_controlnet_aux/pull/74)
+* Fixed LeRes and MiDaS's incomatipility with MPS device
+* Fixed checking DWPose onnxruntime session multiple times: https://github.com/Fannovel16/comfyui_controlnet_aux/issues/89)
+* Added `Anime Face Segmentor` (in `ControlNet Preprocessors/Semantic Segmentation`) for [ControlNet AnimeFaceSegmentV2](https://huggingface.co/bdsqlsz/qinglong_controlnet-lllite#animefacesegmentv2). Checks [here](#anime-face-segmentor)
+* Change download functions and fix [download error](https://github.com/Fannovel16/comfyui_controlnet_aux/issues/39): [PR](https://github.com/Fannovel16/comfyui_controlnet_aux/pull/96)
+* Caching DWPose Onnxruntime during the first use of DWPose node instead of ComfyUI startup
+* Added alternative YOLOX models for faster speed when using DWPose
+* Added alternative DWPose models
+* Implemented the preprocessor for [AnimalPose ControlNet](https://github.com/abehonest/ControlNet_AnimalPose/tree/main). Check [Animal Pose AP-10K](#animal-pose-ap-10k)
+* Added YOLO-NAS models which are drop-in replacements of YOLOX
+* Fixed Openpose Face/Hands no longer detecting: https://github.com/Fannovel16/comfyui_controlnet_aux/issues/54
+* Added TorchScript implementation of DWPose and AnimalPose
+* Added TorchScript implementation of DensePose from [Colab notebook](https://colab.research.google.com/drive/16hcaaKs210ivpxjoyGNuvEXZD4eqOOSQ) which doesn't require detectron2. [Example](#densepose). Thanks [@LayerNome](https://github.com/Layer-norm) for fixing bugs related.
+* Added Standard Lineart Preprocessor
+* Fixed OpenPose misplacements in some cases
+* Added Mesh Graphormer - Hand Depth Map & Mask
+* Misaligned hands bug from MeshGraphormer was fixed
+* Added more mask options for MeshGraphormer
+* Added Save Pose Keypoint node for editing
+* Added Unimatch Optical Flow
+* Added Depth Anything & Zoe Depth Anything
+* Removed resolution field from Unimatch Optical Flow as that interpolating optical flow seems unstable
+* Added TEED Soft-Edge Preprocessor
+* Added DiffusionEdge
+* Added Image Luminance and Image Intensity
+* Added Normal DSINE
+* Added TTPlanet Tile (09/05/2024, DD/MM/YYYY)
+* Added AnyLine, Metric3D (18/05/2024)
+* Added Depth Anything V2 (16/06/2024)
+* Added Union model of ControlNet and preprocessors
+
+* Refactor INPUT_TYPES and add Execute All node during the process of learning [Execution Model Inversion](https://github.com/comfyanonymous/ComfyUI/pull/2666)
+* Added scale_stick_for_xinsr_cn (https://github.com/Fannovel16/comfyui_controlnet_aux/issues/447) (09/04/2024)
+* PyTorch 2.7 compatibility fixes - eliminated custom_timm, custom_detectron2, and custom_midas_repo dependencies causing hanging issues. Refactored 7 major preprocessors including OneFormer (now using HuggingFace transformers), ZOE, DSINE, MiDaS, BAE, Metric3D, and Uniformer. Resolved ~59 GitHub issues related to import failures, hanging, and extension conflicts. Full modernization to actively maintained packages.
\ No newline at end of file
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0de222b1bcab8a10fc4da2af24b305381cea8002
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,224 @@
+import sys, os
+
+# Disable NPU device initialization and problematic MMCV ops to prevent RuntimeError
+# Must be set BEFORE any MMCV imports happen anywhere in ComfyUI
+os.environ['NPU_DEVICE_COUNT'] = '0'
+os.environ['MMCV_WITH_OPS'] = '0'
+from .utils import here, define_preprocessor_inputs, INPUT
+from pathlib import Path
+import traceback
+import importlib
+from .log import log, blue_text, cyan_text, get_summary, get_label
+from .hint_image_enchance import NODE_CLASS_MAPPINGS as HIE_NODE_CLASS_MAPPINGS
+from .hint_image_enchance import NODE_DISPLAY_NAME_MAPPINGS as HIE_NODE_DISPLAY_NAME_MAPPINGS
+#Ref: https://github.com/comfyanonymous/ComfyUI/blob/76d53c4622fc06372975ed2a43ad345935b8a551/nodes.py#L17
+sys.path.insert(0, str(Path(here, "src").resolve()))
+for pkg_name in ["custom_controlnet_aux", "custom_mmpkg"]:
+ sys.path.append(str(Path(here, "src", pkg_name).resolve()))
+
+#Enable CPU fallback for ops not being supported by MPS like upsample_bicubic2d.out
+#https://github.com/pytorch/pytorch/issues/77764
+#https://github.com/Fannovel16/comfyui_controlnet_aux/issues/2#issuecomment-1763579485
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = os.getenv("PYTORCH_ENABLE_MPS_FALLBACK", '1')
+
+
+def load_nodes():
+ shorted_errors = []
+ full_error_messages = []
+ node_class_mappings = {}
+ node_display_name_mappings = {}
+
+ for filename in (here / "node_wrappers").iterdir():
+ module_name = filename.stem
+ if module_name.startswith('.'): continue #Skip hidden files created by the OS (e.g. [.DS_Store](https://en.wikipedia.org/wiki/.DS_Store))
+ try:
+ module = importlib.import_module(
+ f".node_wrappers.{module_name}", package=__package__
+ )
+ node_class_mappings.update(getattr(module, "NODE_CLASS_MAPPINGS"))
+ if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS"):
+ node_display_name_mappings.update(getattr(module, "NODE_DISPLAY_NAME_MAPPINGS"))
+
+ log.debug(f"Imported {module_name} nodes")
+
+ except AttributeError:
+ pass # wip nodes
+ except Exception:
+ error_message = traceback.format_exc()
+ full_error_messages.append(error_message)
+ error_message = error_message.splitlines()[-1]
+ shorted_errors.append(
+ f"Failed to import module {module_name} because {error_message}"
+ )
+
+ if len(shorted_errors) > 0:
+ full_err_log = '\n\n'.join(full_error_messages)
+ print(f"\n\nFull error log from comfyui_controlnet_aux: \n{full_err_log}\n\n")
+ log.info(
+ f"Some nodes failed to load:\n\t"
+ + "\n\t".join(shorted_errors)
+ + "\n\n"
+ + "Check that you properly installed the dependencies.\n"
+ + "If you think this is a bug, please report it on the github page (https://github.com/Fannovel16/comfyui_controlnet_aux/issues)"
+ )
+ return node_class_mappings, node_display_name_mappings
+
+AUX_NODE_MAPPINGS, AUX_DISPLAY_NAME_MAPPINGS = load_nodes()
+
+#For nodes not mapping image to image or has special requirements
+AIO_NOT_SUPPORTED = ["InpaintPreprocessor", "MeshGraphormer+ImpactDetector-DepthMapPreprocessor", "DiffusionEdge_Preprocessor"]
+AIO_NOT_SUPPORTED += ["SavePoseKpsAsJsonFile", "FacialPartColoringFromPoseKps", "UpperBodyTrackingFromPoseKps", "RenderPeopleKps", "RenderAnimalKps"]
+AIO_NOT_SUPPORTED += ["Unimatch_OptFlowPreprocessor", "MaskOptFlow"]
+
+def preprocessor_options():
+ auxs = list(AUX_NODE_MAPPINGS.keys())
+ auxs.insert(0, "none")
+ for name in AIO_NOT_SUPPORTED:
+ if name in auxs:
+ auxs.remove(name)
+ return auxs
+
+
+PREPROCESSOR_OPTIONS = preprocessor_options()
+
+class AIO_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ preprocessor=INPUT.COMBO(PREPROCESSOR_OPTIONS, default="none"),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors"
+
+ def execute(self, preprocessor, image, resolution=512):
+ if preprocessor == "none":
+ return (image, )
+ else:
+ aux_class = AUX_NODE_MAPPINGS[preprocessor]
+ input_types = aux_class.INPUT_TYPES()
+ input_types = {
+ **input_types["required"],
+ **(input_types["optional"] if "optional" in input_types else {})
+ }
+ params = {}
+ for name, input_type in input_types.items():
+ if name == "image":
+ params[name] = image
+ continue
+
+ if name == "resolution":
+ params[name] = resolution
+ continue
+
+ if len(input_type) == 2 and ("default" in input_type[1]):
+ params[name] = input_type[1]["default"]
+ continue
+
+ default_values = { "INT": 0, "FLOAT": 0.0 }
+ if type(input_type[0]) is list:
+ for input_type_value in input_type[0]:
+ if input_type_value in default_values:
+ params[name] = default_values[input_type[0]]
+ else:
+ if input_type[0] in default_values:
+ params[name] = default_values[input_type[0]]
+
+ return getattr(aux_class(), aux_class.FUNCTION)(**params)
+
+class ControlNetAuxSimpleAddText:
+ @classmethod
+ def INPUT_TYPES(s):
+ return dict(
+ required=dict(image=INPUT.IMAGE(), text=INPUT.STRING())
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+ CATEGORY = "ControlNet Preprocessors"
+ def execute(self, image, text):
+ from PIL import Image, ImageDraw, ImageFont
+ import numpy as np
+ import torch
+
+ font = ImageFont.truetype(str((here / "NotoSans-Regular.ttf").resolve()), 40)
+ img = Image.fromarray(image[0].cpu().numpy().__mul__(255.).astype(np.uint8))
+ ImageDraw.Draw(img).text((0,0), text, fill=(0,255,0), font=font)
+ return (torch.from_numpy(np.array(img)).unsqueeze(0) / 255.,)
+
+class ExecuteAllControlNetPreprocessors:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors"
+
+ def execute(self, image, resolution=512):
+ try:
+ from comfy_execution.graph_utils import GraphBuilder
+ except:
+ raise RuntimeError("ExecuteAllControlNetPreprocessor requries [Execution Model Inversion](https://github.com/comfyanonymous/ComfyUI/commit/5cfe38). Update ComfyUI/SwarmUI to get this feature")
+
+ graph = GraphBuilder()
+ curr_outputs = []
+ for preprocc in PREPROCESSOR_OPTIONS:
+ preprocc_node = graph.node("AIO_Preprocessor", preprocessor=preprocc, image=image, resolution=resolution)
+ hint_img = preprocc_node.out(0)
+ add_text_node = graph.node("ControlNetAuxSimpleAddText", image=hint_img, text=preprocc)
+ curr_outputs.append(add_text_node.out(0))
+
+ while len(curr_outputs) > 1:
+ _outputs = []
+ for i in range(0, len(curr_outputs), 2):
+ if i+1 < len(curr_outputs):
+ image_batch = graph.node("ImageBatch", image1=curr_outputs[i], image2=curr_outputs[i+1])
+ _outputs.append(image_batch.out(0))
+ else:
+ _outputs.append(curr_outputs[i])
+ curr_outputs = _outputs
+
+ return {
+ "result": (curr_outputs[0],),
+ "expand": graph.finalize(),
+ }
+
+class ControlNetPreprocessorSelector:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "preprocessor": (PREPROCESSOR_OPTIONS,),
+ }
+ }
+
+ RETURN_TYPES = (PREPROCESSOR_OPTIONS,)
+ RETURN_NAMES = ("preprocessor",)
+ FUNCTION = "get_preprocessor"
+
+ CATEGORY = "ControlNet Preprocessors"
+
+ def get_preprocessor(self, preprocessor: str):
+ return (preprocessor,)
+
+
+NODE_CLASS_MAPPINGS = {
+ **AUX_NODE_MAPPINGS,
+ "AIO_Preprocessor": AIO_Preprocessor,
+ "ControlNetPreprocessorSelector": ControlNetPreprocessorSelector,
+ **HIE_NODE_CLASS_MAPPINGS,
+ "ExecuteAllControlNetPreprocessors": ExecuteAllControlNetPreprocessors,
+ "ControlNetAuxSimpleAddText": ControlNetAuxSimpleAddText
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+ **AUX_DISPLAY_NAME_MAPPINGS,
+ "AIO_Preprocessor": "AIO Aux Preprocessor",
+ "ControlNetPreprocessorSelector": "Preprocessor Selector",
+ **HIE_NODE_DISPLAY_NAME_MAPPINGS,
+ "ExecuteAllControlNetPreprocessors": "Execute All ControlNet Preprocessors"
+}
diff --git a/config.example.yaml b/config.example.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6734ad02ba275e4f3dd820e448f841a4cd156dd3
--- /dev/null
+++ b/config.example.yaml
@@ -0,0 +1,20 @@
+# this is an example for config.yaml file, you can rename it to config.yaml if you want to use it
+# ###############################################################################################
+# This path is for custom pressesor models base folder. default is "./ckpts"
+# you can also use absolute paths like: "/root/ComfyUI/custom_nodes/comfyui_controlnet_aux/ckpts" or "D:\\ComfyUI\\custom_nodes\\comfyui_controlnet_aux\\ckpts"
+annotator_ckpts_path: "./ckpts"
+# ###############################################################################################
+# This path is for downloading temporary files.
+# You SHOULD use absolute path for this like"D:\\temp", DO NOT use relative paths. Empty for default.
+custom_temp_path:
+# ###############################################################################################
+# if you already have downloaded ckpts via huggingface hub into default cache path like: ~/.cache/huggingface/hub, you can set this True to use symlinks to save space
+USE_SYMLINKS: False
+# ###############################################################################################
+# EP_list is a list of execution providers for onnxruntime, if one of them is not available or not working well, you can delete that provider from here(config.yaml)
+# you can find all available providers here: https://onnxruntime.ai/docs/execution-providers
+# for example, if you have CUDA installed, you can set it to: ["CUDAExecutionProvider", "CPUExecutionProvider"]
+# empty list or only keep ["CPUExecutionProvider"] means you use cv2.dnn.readNetFromONNX to load onnx models
+# if your onnx models can only run on the CPU or have other issues, we recommend using pt model instead.
+# default value is ["CUDAExecutionProvider", "DirectMLExecutionProvider", "OpenVINOExecutionProvider", "ROCMExecutionProvider", "CPUExecutionProvider"]
+EP_list: ["CUDAExecutionProvider", "DirectMLExecutionProvider", "OpenVINOExecutionProvider", "ROCMExecutionProvider", "CPUExecutionProvider"]
diff --git a/dev_interface.py b/dev_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c08abdbbbdfa154f3b09c64d63ce7c517491337
--- /dev/null
+++ b/dev_interface.py
@@ -0,0 +1,6 @@
+from pathlib import Path
+from utils import here
+import sys
+sys.path.append(str(Path(here, "src")))
+
+from custom_controlnet_aux import *
\ No newline at end of file
diff --git a/examples/CNAuxBanner.jpg b/examples/CNAuxBanner.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..42621fb16301b731b9f8853b77cef9f3a209d370
--- /dev/null
+++ b/examples/CNAuxBanner.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ee78fe95215ac7cedec155f7df25707858b11add1f15c8a06e978deeef8a1666
+size 589690
diff --git a/examples/ExecuteAll.png b/examples/ExecuteAll.png
new file mode 100644
index 0000000000000000000000000000000000000000..238e96e327b8a0a2c738591f91a0a77b96a7909c
--- /dev/null
+++ b/examples/ExecuteAll.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f9e9d1f3cb4d13005818cf7c14e04be0635b44e180776dce7f02f715e246d18e
+size 10007102
diff --git a/examples/ExecuteAll1.jpg b/examples/ExecuteAll1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c01db1275c3e6d8a334d6d5b38f9c6cafbd4af1d
--- /dev/null
+++ b/examples/ExecuteAll1.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6750043f6866ec52ae348ce108f8b7361c6e30a744cef162f289bbb2296cdad9
+size 1171712
diff --git a/examples/ExecuteAll2.jpg b/examples/ExecuteAll2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..03796e2d40734f108d454bfaa073f7dfe82d239a
--- /dev/null
+++ b/examples/ExecuteAll2.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:281478f1d39ab9d2ba8b0aa5d3ab33a34c4dc993e078b3bcca9ffbf024f2505b
+size 1021442
diff --git a/examples/comfyui-controlnet-aux-logo.png b/examples/comfyui-controlnet-aux-logo.png
new file mode 100644
index 0000000000000000000000000000000000000000..28082eb86a7cf726c9d783fa99e8cd1a980ea232
--- /dev/null
+++ b/examples/comfyui-controlnet-aux-logo.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5e564b8171c2cdd6a3056c079dd0e096cfcd877225c318c63962500f13c313c8
+size 710215
diff --git a/examples/example_animal_pose.png b/examples/example_animal_pose.png
new file mode 100644
index 0000000000000000000000000000000000000000..2c2e34821a9f173dd35b2203347bd4b5eb9cce0f
--- /dev/null
+++ b/examples/example_animal_pose.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:53d5485dfe3c1d512773e3fbcb063b8cb5ff2a9b4d9bd814d41120e4636d7711
+size 722786
diff --git a/examples/example_anime_face_segmentor.png b/examples/example_anime_face_segmentor.png
new file mode 100644
index 0000000000000000000000000000000000000000..2c8a589a8037a5b86769167fba475f460a5d90c8
--- /dev/null
+++ b/examples/example_anime_face_segmentor.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:561a225461ba56ffb33491dab0bd2b9d053cffb21c98a46a4281f53725a0f863
+size 483000
diff --git a/examples/example_anyline.png b/examples/example_anyline.png
new file mode 100644
index 0000000000000000000000000000000000000000..9369205d566123d59fe72b18995805f093f9962b
--- /dev/null
+++ b/examples/example_anyline.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:28e2ffe0c96d7c7d44c45ff10c6754ef9741c638caa06c619774f82a0d4e12c5
+size 379500
diff --git a/examples/example_densepose.png b/examples/example_densepose.png
new file mode 100644
index 0000000000000000000000000000000000000000..bbfc71c830280ab21041bec457b4d41675f461f2
--- /dev/null
+++ b/examples/example_densepose.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:89e5840fe9ba3be4cecd5842de7c16fd36d36b5ded597c316a4479f668e6e420
+size 277796
diff --git a/examples/example_depth_anything.png b/examples/example_depth_anything.png
new file mode 100644
index 0000000000000000000000000000000000000000..733e4a974c731fc7b2e5b2b33aebb3f7c531a032
--- /dev/null
+++ b/examples/example_depth_anything.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c6da6f0b1a4656a193a1797c2b235889c45cd8f08ae4405ebf95d1fe6d2aa70
+size 448662
diff --git a/examples/example_depth_anything_v2.png b/examples/example_depth_anything_v2.png
new file mode 100644
index 0000000000000000000000000000000000000000..b47e60a901e22302d6f20234714cd14cd3d484b9
--- /dev/null
+++ b/examples/example_depth_anything_v2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7b3f78b0ba41d13f8097722f936822f81acf1df0b318d2319c352dab7576c03b
+size 184526
diff --git a/examples/example_dsine.png b/examples/example_dsine.png
new file mode 100644
index 0000000000000000000000000000000000000000..98eeed056a32fae7635c5fc735a5101d34387ca7
--- /dev/null
+++ b/examples/example_dsine.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f136a86046455848b35847cabf0f0801c37b117d52beebdbdb21f5dccfde3291
+size 651038
diff --git a/examples/example_marigold.png b/examples/example_marigold.png
new file mode 100644
index 0000000000000000000000000000000000000000..c0f916bbff15708ccb4c927a37289ddc9bbea0c3
--- /dev/null
+++ b/examples/example_marigold.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d3da4c5a5e934e1912b67a737df9aec3b12b893403b650545bdc69b06a006038
+size 661365
diff --git a/examples/example_marigold_flat.jpg b/examples/example_marigold_flat.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a2f304371400aed1bdbb1894aea76f7a5d2ba79e
--- /dev/null
+++ b/examples/example_marigold_flat.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7182b223c53093d28cfd6c34e3d26083658eefac0bc22f0f88a801b589dcdee1
+size 300798
diff --git a/examples/example_mesh_graphormer.png b/examples/example_mesh_graphormer.png
new file mode 100644
index 0000000000000000000000000000000000000000..90b162d7c91beda88ccf81a4358ba44914596fd2
--- /dev/null
+++ b/examples/example_mesh_graphormer.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d7692c5df1ee107b95c02455eea4e88f878b59313a98b5447736c9c417c0e182
+size 5481152
diff --git a/examples/example_metric3d.png b/examples/example_metric3d.png
new file mode 100644
index 0000000000000000000000000000000000000000..be3bbeb0ed4ab2a6848c227e3d95600601064a08
--- /dev/null
+++ b/examples/example_metric3d.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:080c52e5f27d1cba63368f6ad277b2a2f9b13433e6d8621436f5ad926157917d
+size 249587
diff --git a/examples/example_onnx.png b/examples/example_onnx.png
new file mode 100644
index 0000000000000000000000000000000000000000..f3f9ad5a45e2ce33b03883446b55fb487c059a00
Binary files /dev/null and b/examples/example_onnx.png differ
diff --git a/examples/example_recolor.png b/examples/example_recolor.png
new file mode 100644
index 0000000000000000000000000000000000000000..4750a3396835adb520620931ef3f73cd3d58ab62
--- /dev/null
+++ b/examples/example_recolor.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cc0eecdecf9f698dbca0b38e08834cfd5f232c12a11b63945e75e6af7b3ce366
+size 730211
diff --git a/examples/example_save_kps.png b/examples/example_save_kps.png
new file mode 100644
index 0000000000000000000000000000000000000000..c6b97029d486eefa180b18bae209cf44678fd78c
--- /dev/null
+++ b/examples/example_save_kps.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ec9f6d86d0d620458e6f8b85a4e7798525c5d50b314e81b7e8234182edc3a62b
+size 409884
diff --git a/examples/example_teed.png b/examples/example_teed.png
new file mode 100644
index 0000000000000000000000000000000000000000..520d2fe6e5cdf1ab0a9db4a1e31001ab4a96cb29
--- /dev/null
+++ b/examples/example_teed.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aea1e8799a1d600a64a900dc8b71b07f30e86e30f12777db3e5dec7039afa7ce
+size 531265
diff --git a/examples/example_torchscript.png b/examples/example_torchscript.png
new file mode 100644
index 0000000000000000000000000000000000000000..6d50d424372d01015dd0a91fa4656823237f7de1
--- /dev/null
+++ b/examples/example_torchscript.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:056c3a6d78503b3d4aa4a6556e53498b34593b0220d7d447f654304ac0209fa6
+size 109440
diff --git a/examples/example_unimatch.png b/examples/example_unimatch.png
new file mode 100644
index 0000000000000000000000000000000000000000..60f77d14ccea695774a51b613938ba7ea778b04a
--- /dev/null
+++ b/examples/example_unimatch.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09a4b0f71140532e0a0c679bdff9023f5e01012df1294eafffc08ced5848acb8
+size 249664
diff --git a/hint_image_enchance.py b/hint_image_enchance.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb2a06974d2e23e53218706248d6340ed99de61a
--- /dev/null
+++ b/hint_image_enchance.py
@@ -0,0 +1,233 @@
+from .log import log
+from .utils import ResizeMode, safe_numpy
+import numpy as np
+import torch
+import cv2
+from .utils import get_unique_axis0
+from .lvminthin import nake_nms, lvmin_thin
+
+MAX_IMAGEGEN_RESOLUTION = 8192 #https://github.com/comfyanonymous/ComfyUI/blob/c910b4a01ca58b04e5d4ab4c747680b996ada02b/nodes.py#L42
+RESIZE_MODES = [ResizeMode.RESIZE.value, ResizeMode.INNER_FIT.value, ResizeMode.OUTER_FIT.value]
+
+#Port from https://github.com/Mikubill/sd-webui-controlnet/blob/e67e017731aad05796b9615dc6eadce911298ea1/internal_controlnet/external_code.py#L89
+class PixelPerfectResolution:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "original_image": ("IMAGE", ),
+ "image_gen_width": ("INT", {"default": 512, "min": 64, "max": MAX_IMAGEGEN_RESOLUTION, "step": 8}),
+ "image_gen_height": ("INT", {"default": 512, "min": 64, "max": MAX_IMAGEGEN_RESOLUTION, "step": 8}),
+ #https://github.com/comfyanonymous/ComfyUI/blob/c910b4a01ca58b04e5d4ab4c747680b996ada02b/nodes.py#L854
+ "resize_mode": (RESIZE_MODES, {"default": ResizeMode.RESIZE.value})
+ }
+ }
+
+ RETURN_TYPES = ("INT",)
+ RETURN_NAMES = ("RESOLUTION (INT)", )
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors"
+
+ def execute(self, original_image, image_gen_width, image_gen_height, resize_mode):
+ _, raw_H, raw_W, _ = original_image.shape
+
+ k0 = float(image_gen_height) / float(raw_H)
+ k1 = float(image_gen_width) / float(raw_W)
+
+ if resize_mode == ResizeMode.OUTER_FIT.value:
+ estimation = min(k0, k1) * float(min(raw_H, raw_W))
+ else:
+ estimation = max(k0, k1) * float(min(raw_H, raw_W))
+
+ log.debug(f"Pixel Perfect Computation:")
+ log.debug(f"resize_mode = {resize_mode}")
+ log.debug(f"raw_H = {raw_H}")
+ log.debug(f"raw_W = {raw_W}")
+ log.debug(f"target_H = {image_gen_height}")
+ log.debug(f"target_W = {image_gen_width}")
+ log.debug(f"estimation = {estimation}")
+
+ return (int(np.round(estimation)), )
+
+class HintImageEnchance:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "hint_image": ("IMAGE", ),
+ "image_gen_width": ("INT", {"default": 512, "min": 64, "max": MAX_IMAGEGEN_RESOLUTION, "step": 8}),
+ "image_gen_height": ("INT", {"default": 512, "min": 64, "max": MAX_IMAGEGEN_RESOLUTION, "step": 8}),
+ #https://github.com/comfyanonymous/ComfyUI/blob/c910b4a01ca58b04e5d4ab4c747680b996ada02b/nodes.py#L854
+ "resize_mode": (RESIZE_MODES, {"default": ResizeMode.RESIZE.value})
+ }
+ }
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors"
+ def execute(self, hint_image, image_gen_width, image_gen_height, resize_mode):
+ outs = []
+ for single_hint_image in hint_image:
+ np_hint_image = np.asarray(single_hint_image * 255., dtype=np.uint8)
+
+ if resize_mode == ResizeMode.RESIZE.value:
+ np_hint_image = self.execute_resize(np_hint_image, image_gen_width, image_gen_height)
+ elif resize_mode == ResizeMode.OUTER_FIT.value:
+ np_hint_image = self.execute_outer_fit(np_hint_image, image_gen_width, image_gen_height)
+ else:
+ np_hint_image = self.execute_inner_fit(np_hint_image, image_gen_width, image_gen_height)
+
+ outs.append(torch.from_numpy(np_hint_image.astype(np.float32) / 255.0))
+
+ return (torch.stack(outs, dim=0),)
+
+ def execute_resize(self, detected_map, w, h):
+ detected_map = self.high_quality_resize(detected_map, (w, h))
+ detected_map = safe_numpy(detected_map)
+ return detected_map
+
+ def execute_outer_fit(self, detected_map, w, h):
+ old_h, old_w, _ = detected_map.shape
+ old_w = float(old_w)
+ old_h = float(old_h)
+ k0 = float(h) / old_h
+ k1 = float(w) / old_w
+ safeint = lambda x: int(np.round(x))
+ k = min(k0, k1)
+
+ borders = np.concatenate([detected_map[0, :, :], detected_map[-1, :, :], detected_map[:, 0, :], detected_map[:, -1, :]], axis=0)
+ high_quality_border_color = np.median(borders, axis=0).astype(detected_map.dtype)
+ if len(high_quality_border_color) == 4:
+ # Inpaint hijack
+ high_quality_border_color[3] = 255
+ high_quality_background = np.tile(high_quality_border_color[None, None], [h, w, 1])
+ detected_map = self.high_quality_resize(detected_map, (safeint(old_w * k), safeint(old_h * k)))
+ new_h, new_w, _ = detected_map.shape
+ pad_h = max(0, (h - new_h) // 2)
+ pad_w = max(0, (w - new_w) // 2)
+ high_quality_background[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = detected_map
+ detected_map = high_quality_background
+ detected_map = safe_numpy(detected_map)
+ return detected_map
+
+ def execute_inner_fit(self, detected_map, w, h):
+ old_h, old_w, _ = detected_map.shape
+ old_w = float(old_w)
+ old_h = float(old_h)
+ k0 = float(h) / old_h
+ k1 = float(w) / old_w
+ safeint = lambda x: int(np.round(x))
+ k = max(k0, k1)
+
+ detected_map = self.high_quality_resize(detected_map, (safeint(old_w * k), safeint(old_h * k)))
+ new_h, new_w, _ = detected_map.shape
+ pad_h = max(0, (new_h - h) // 2)
+ pad_w = max(0, (new_w - w) // 2)
+ detected_map = detected_map[pad_h:pad_h+h, pad_w:pad_w+w]
+ detected_map = safe_numpy(detected_map)
+ return detected_map
+
+ def high_quality_resize(self, x, size):
+ # Written by lvmin
+ # Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges
+
+ inpaint_mask = None
+ if x.ndim == 3 and x.shape[2] == 4:
+ inpaint_mask = x[:, :, 3]
+ x = x[:, :, 0:3]
+
+ if x.shape[0] != size[1] or x.shape[1] != size[0]:
+ new_size_is_smaller = (size[0] * size[1]) < (x.shape[0] * x.shape[1])
+ new_size_is_bigger = (size[0] * size[1]) > (x.shape[0] * x.shape[1])
+ unique_color_count = len(get_unique_axis0(x.reshape(-1, x.shape[2])))
+ is_one_pixel_edge = False
+ is_binary = False
+ if unique_color_count == 2:
+ is_binary = np.min(x) < 16 and np.max(x) > 240
+ if is_binary:
+ xc = x
+ xc = cv2.erode(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
+ xc = cv2.dilate(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
+ one_pixel_edge_count = np.where(xc < x)[0].shape[0]
+ all_edge_count = np.where(x > 127)[0].shape[0]
+ is_one_pixel_edge = one_pixel_edge_count * 2 > all_edge_count
+
+ if 2 < unique_color_count < 200:
+ interpolation = cv2.INTER_NEAREST
+ elif new_size_is_smaller:
+ interpolation = cv2.INTER_AREA
+ else:
+ interpolation = cv2.INTER_CUBIC # Must be CUBIC because we now use nms. NEVER CHANGE THIS
+
+ y = cv2.resize(x, size, interpolation=interpolation)
+ if inpaint_mask is not None:
+ inpaint_mask = cv2.resize(inpaint_mask, size, interpolation=interpolation)
+
+ if is_binary:
+ y = np.mean(y.astype(np.float32), axis=2).clip(0, 255).astype(np.uint8)
+ if is_one_pixel_edge:
+ y = nake_nms(y)
+ _, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+ y = lvmin_thin(y, prunings=new_size_is_bigger)
+ else:
+ _, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+ y = np.stack([y] * 3, axis=2)
+ else:
+ y = x
+
+ if inpaint_mask is not None:
+ inpaint_mask = (inpaint_mask > 127).astype(np.float32) * 255.0
+ inpaint_mask = inpaint_mask[:, :, None].clip(0, 255).astype(np.uint8)
+ y = np.concatenate([y, inpaint_mask], axis=2)
+
+ return y
+
+
+class ImageGenResolutionFromLatent:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": { "latent": ("LATENT", ) }
+ }
+
+ RETURN_TYPES = ("INT", "INT")
+ RETURN_NAMES = ("IMAGE_GEN_WIDTH (INT)", "IMAGE_GEN_HEIGHT (INT)")
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors"
+
+ def execute(self, latent):
+ _, _, H, W = latent["samples"].shape
+ return (W * 8, H * 8)
+
+class ImageGenResolutionFromImage:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": { "image": ("IMAGE", ) }
+ }
+
+ RETURN_TYPES = ("INT", "INT")
+ RETURN_NAMES = ("IMAGE_GEN_WIDTH (INT)", "IMAGE_GEN_HEIGHT (INT)")
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors"
+
+ def execute(self, image):
+ _, H, W, _ = image.shape
+ return (W, H)
+
+NODE_CLASS_MAPPINGS = {
+ "PixelPerfectResolution": PixelPerfectResolution,
+ "ImageGenResolutionFromImage": ImageGenResolutionFromImage,
+ "ImageGenResolutionFromLatent": ImageGenResolutionFromLatent,
+ "HintImageEnchance": HintImageEnchance
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "PixelPerfectResolution": "Pixel Perfect Resolution",
+ "ImageGenResolutionFromImage": "Generation Resolution From Image",
+ "ImageGenResolutionFromLatent": "Generation Resolution From Latent",
+ "HintImageEnchance": "Enchance And Resize Hint Images"
+}
\ No newline at end of file
diff --git a/install.bat b/install.bat
new file mode 100644
index 0000000000000000000000000000000000000000..c36a67448534a5febc2a83d4ebef4cfa49fa6deb
--- /dev/null
+++ b/install.bat
@@ -0,0 +1,20 @@
+@echo off
+
+set "requirements_txt=%~dp0\requirements.txt"
+set "python_exec=..\..\..\python_embeded\python.exe"
+
+echo Installing ComfyUI's ControlNet Auxiliary Preprocessors..
+
+if exist "%python_exec%" (
+ echo Installing with ComfyUI Portable
+ for /f "delims=" %%i in (%requirements_txt%) do (
+ %python_exec% -s -m pip install "%%i"
+ )
+) else (
+ echo Installing with system Python
+ for /f "delims=" %%i in (%requirements_txt%) do (
+ pip install "%%i"
+ )
+)
+
+pause
\ No newline at end of file
diff --git a/log.py b/log.py
new file mode 100644
index 0000000000000000000000000000000000000000..2978c6dc770c78feddc2d0dece7c0b6a91ed23f0
--- /dev/null
+++ b/log.py
@@ -0,0 +1,80 @@
+#Cre: https://github.com/melMass/comfy_mtb/blob/main/log.py
+import logging
+import re
+import os
+
+base_log_level = logging.INFO
+
+
+# Custom object that discards the output
+class NullWriter:
+ def write(self, text):
+ pass
+
+
+class Formatter(logging.Formatter):
+ grey = "\x1b[38;20m"
+ cyan = "\x1b[36;20m"
+ purple = "\x1b[35;20m"
+ yellow = "\x1b[33;20m"
+ red = "\x1b[31;20m"
+ bold_red = "\x1b[31;1m"
+ reset = "\x1b[0m"
+ # format = "%(asctime)s - [%(name)s] - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)"
+ format = "[%(name)s] | %(levelname)s -> %(message)s"
+
+ FORMATS = {
+ logging.DEBUG: purple + format + reset,
+ logging.INFO: cyan + format + reset,
+ logging.WARNING: yellow + format + reset,
+ logging.ERROR: red + format + reset,
+ logging.CRITICAL: bold_red + format + reset,
+ }
+
+ def format(self, record):
+ log_fmt = self.FORMATS.get(record.levelno)
+ formatter = logging.Formatter(log_fmt)
+ return formatter.format(record)
+
+
+def mklog(name, level=base_log_level):
+ logger = logging.getLogger(name)
+ logger.setLevel(level)
+
+ for handler in logger.handlers:
+ logger.removeHandler(handler)
+
+ ch = logging.StreamHandler()
+ ch.setLevel(level)
+ ch.setFormatter(Formatter())
+ logger.addHandler(ch)
+
+ # Disable log propagation
+ logger.propagate = False
+
+ return logger
+
+
+# - The main app logger
+log = mklog(__package__, base_log_level)
+
+
+def log_user(arg):
+ print("\033[34mComfyUI ControlNet AUX:\033[0m {arg}")
+
+
+def get_summary(docstring):
+ return docstring.strip().split("\n\n", 1)[0]
+
+
+def blue_text(text):
+ return f"\033[94m{text}\033[0m"
+
+
+def cyan_text(text):
+ return f"\033[96m{text}\033[0m"
+
+
+def get_label(label):
+ words = re.findall(r"(?:^|[A-Z])[a-z]*", label)
+ return " ".join(words).strip()
\ No newline at end of file
diff --git a/lvminthin.py b/lvminthin.py
new file mode 100644
index 0000000000000000000000000000000000000000..eebe0fb6d4967d0f6c38c0117ce4775d46d0be08
--- /dev/null
+++ b/lvminthin.py
@@ -0,0 +1,87 @@
+# High Quality Edge Thinning using Pure Python
+# Written by Lvmin Zhang
+# 2023 April
+# Stanford University
+# If you use this, please Cite "High Quality Edge Thinning using Pure Python", Lvmin Zhang, In Mikubill/sd-webui-controlnet.
+
+
+import cv2
+import numpy as np
+
+
+lvmin_kernels_raw = [
+ np.array([
+ [-1, -1, -1],
+ [0, 1, 0],
+ [1, 1, 1]
+ ], dtype=np.int32),
+ np.array([
+ [0, -1, -1],
+ [1, 1, -1],
+ [0, 1, 0]
+ ], dtype=np.int32)
+]
+
+lvmin_kernels = []
+lvmin_kernels += [np.rot90(x, k=0, axes=(0, 1)) for x in lvmin_kernels_raw]
+lvmin_kernels += [np.rot90(x, k=1, axes=(0, 1)) for x in lvmin_kernels_raw]
+lvmin_kernels += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_kernels_raw]
+lvmin_kernels += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_kernels_raw]
+
+lvmin_prunings_raw = [
+ np.array([
+ [-1, -1, -1],
+ [-1, 1, -1],
+ [0, 0, -1]
+ ], dtype=np.int32),
+ np.array([
+ [-1, -1, -1],
+ [-1, 1, -1],
+ [-1, 0, 0]
+ ], dtype=np.int32)
+]
+
+lvmin_prunings = []
+lvmin_prunings += [np.rot90(x, k=0, axes=(0, 1)) for x in lvmin_prunings_raw]
+lvmin_prunings += [np.rot90(x, k=1, axes=(0, 1)) for x in lvmin_prunings_raw]
+lvmin_prunings += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_prunings_raw]
+lvmin_prunings += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_prunings_raw]
+
+
+def remove_pattern(x, kernel):
+ objects = cv2.morphologyEx(x, cv2.MORPH_HITMISS, kernel)
+ objects = np.where(objects > 127)
+ x[objects] = 0
+ return x, objects[0].shape[0] > 0
+
+
+def thin_one_time(x, kernels):
+ y = x
+ is_done = True
+ for k in kernels:
+ y, has_update = remove_pattern(y, k)
+ if has_update:
+ is_done = False
+ return y, is_done
+
+
+def lvmin_thin(x, prunings=True):
+ y = x
+ for i in range(32):
+ y, is_done = thin_one_time(y, lvmin_kernels)
+ if is_done:
+ break
+ if prunings:
+ y, _ = thin_one_time(y, lvmin_prunings)
+ return y
+
+
+def nake_nms(x):
+ f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
+ f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
+ f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
+ f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
+ y = np.zeros_like(x)
+ for f in [f1, f2, f3, f4]:
+ np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
+ return y
diff --git a/node_wrappers/anime_face_segment.py b/node_wrappers/anime_face_segment.py
new file mode 100644
index 0000000000000000000000000000000000000000..e642aa08b09ca85ccf8b61aa98531f862e8b3075
--- /dev/null
+++ b/node_wrappers/anime_face_segment.py
@@ -0,0 +1,43 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
+import comfy.model_management as model_management
+import torch
+from einops import rearrange
+
+class AnimeFace_SemSegPreprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ #This preprocessor is only trained on 512x resolution
+ #https://github.com/siyeong0/Anime-Face-Segmentation/blob/main/predict.py#L25
+ return define_preprocessor_inputs(
+ remove_background_using_abg=INPUT.BOOLEAN(True),
+ resolution=INPUT.RESOLUTION(default=512, min=512, max=512)
+ )
+
+ RETURN_TYPES = ("IMAGE", "MASK")
+ RETURN_NAMES = ("IMAGE", "ABG_CHARACTER_MASK (MASK)")
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Semantic Segmentation"
+
+ def execute(self, image, remove_background_using_abg=True, resolution=512, **kwargs):
+ from custom_controlnet_aux.anime_face_segment import AnimeFaceSegmentor
+
+ model = AnimeFaceSegmentor.from_pretrained().to(model_management.get_torch_device())
+ if remove_background_using_abg:
+ out_image_with_mask = common_annotator_call(model, image, resolution=resolution, remove_background=True)
+ out_image = out_image_with_mask[..., :3]
+ mask = out_image_with_mask[..., 3:]
+ mask = rearrange(mask, "n h w c -> n c h w")
+ else:
+ out_image = common_annotator_call(model, image, resolution=resolution, remove_background=False)
+ N, H, W, C = out_image.shape
+ mask = torch.ones(N, C, H, W)
+ del model
+ return (out_image, mask)
+
+NODE_CLASS_MAPPINGS = {
+ "AnimeFace_SemSegPreprocessor": AnimeFace_SemSegPreprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "AnimeFace_SemSegPreprocessor": "Anime Face Segmentor"
+}
\ No newline at end of file
diff --git a/node_wrappers/anyline.py b/node_wrappers/anyline.py
new file mode 100644
index 0000000000000000000000000000000000000000..187e90a80eca12ccc1df5b1aa4482e255b97b363
--- /dev/null
+++ b/node_wrappers/anyline.py
@@ -0,0 +1,87 @@
+import torch
+import numpy as np
+import comfy.model_management as model_management
+import comfy.utils
+
+# Requires comfyui_controlnet_aux funcsions and classes
+from ..utils import common_annotator_call, INPUT, define_preprocessor_inputs
+
+def get_intensity_mask(image_array, lower_bound, upper_bound):
+ mask = image_array[:, :, 0]
+ mask = np.where((mask >= lower_bound) & (mask <= upper_bound), mask, 0)
+ mask = np.expand_dims(mask, 2).repeat(3, axis=2)
+ return mask
+
+def combine_layers(base_layer, top_layer):
+ mask = top_layer.astype(bool)
+ temp = 1 - (1 - top_layer) * (1 - base_layer)
+ result = base_layer * (~mask) + temp * mask
+ return result
+
+class AnyLinePreprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ merge_with_lineart=INPUT.COMBO(["lineart_standard", "lineart_realisitic", "lineart_anime", "manga_line"], default="lineart_standard"),
+ resolution=INPUT.RESOLUTION(default=1280, step=8),
+ lineart_lower_bound=INPUT.FLOAT(default=0),
+ lineart_upper_bound=INPUT.FLOAT(default=1),
+ object_min_size=INPUT.INT(default=36, min=1),
+ object_connectivity=INPUT.INT(default=1, min=1)
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ RETURN_NAMES = ("image",)
+
+ FUNCTION = "get_anyline"
+ CATEGORY = "ControlNet Preprocessors/Line Extractors"
+
+ def __init__(self):
+ self.device = model_management.get_torch_device()
+
+ def get_anyline(self, image, merge_with_lineart="lineart_standard", resolution=512, lineart_lower_bound=0, lineart_upper_bound=1, object_min_size=36, object_connectivity=1):
+ from custom_controlnet_aux.teed import TEDDetector
+ from skimage import morphology
+ pbar = comfy.utils.ProgressBar(3)
+
+ # Process the image with MTEED model
+ mteed_model = TEDDetector.from_pretrained("TheMistoAI/MistoLine", "MTEED.pth", subfolder="Anyline").to(self.device)
+ mteed_result = common_annotator_call(mteed_model, image, resolution=resolution, show_pbar=False)
+ mteed_result = mteed_result.numpy()
+ del mteed_model
+ pbar.update(1)
+
+ # Process the image with the lineart standard preprocessor
+ if merge_with_lineart == "lineart_standard":
+ from custom_controlnet_aux.lineart_standard import LineartStandardDetector
+ lineart_standard_detector = LineartStandardDetector()
+ lineart_result = common_annotator_call(lineart_standard_detector, image, guassian_sigma=2, intensity_threshold=3, resolution=resolution, show_pbar=False).numpy()
+ del lineart_standard_detector
+ else:
+ from custom_controlnet_aux.lineart import LineartDetector
+ from custom_controlnet_aux.lineart_anime import LineartAnimeDetector
+ from custom_controlnet_aux.manga_line import LineartMangaDetector
+ lineart_detector = dict(lineart_realisitic=LineartDetector, lineart_anime=LineartAnimeDetector, manga_line=LineartMangaDetector)[merge_with_lineart]
+ lineart_detector = lineart_detector.from_pretrained().to(self.device)
+ lineart_result = common_annotator_call(lineart_detector, image, resolution=resolution, show_pbar=False).numpy()
+ del lineart_detector
+ pbar.update(1)
+
+ final_result = []
+ for i in range(len(image)):
+ _lineart_result = get_intensity_mask(lineart_result[i], lower_bound=lineart_lower_bound, upper_bound=lineart_upper_bound)
+ _cleaned = morphology.remove_small_objects(_lineart_result.astype(bool), min_size=object_min_size, connectivity=object_connectivity)
+ _lineart_result = _lineart_result * _cleaned
+ _mteed_result = mteed_result[i]
+
+ # Combine the results
+ final_result.append(torch.from_numpy(combine_layers(_mteed_result, _lineart_result)))
+ pbar.update(1)
+ return (torch.stack(final_result),)
+
+NODE_CLASS_MAPPINGS = {
+ "AnyLineArtPreprocessor_aux": AnyLinePreprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "AnyLineArtPreprocessor_aux": "AnyLine Lineart"
+}
diff --git a/node_wrappers/binary.py b/node_wrappers/binary.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c3eb6be5808053a468bccd6b1bcfb678518eb90
--- /dev/null
+++ b/node_wrappers/binary.py
@@ -0,0 +1,29 @@
+from ..utils import common_annotator_call, INPUT, define_preprocessor_inputs
+import comfy.model_management as model_management
+
+class Binary_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ bin_threshold=INPUT.INT(default=100, max=255),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Line Extractors"
+
+ def execute(self, image, bin_threshold=100, resolution=512, **kwargs):
+ from custom_controlnet_aux.binary import BinaryDetector
+
+ return (common_annotator_call(BinaryDetector(), image, bin_threshold=bin_threshold, resolution=resolution), )
+
+
+
+NODE_CLASS_MAPPINGS = {
+ "BinaryPreprocessor": Binary_Preprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "BinaryPreprocessor": "Binary Lines"
+}
\ No newline at end of file
diff --git a/node_wrappers/canny.py b/node_wrappers/canny.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcbe510e553e8eeaac037ebc461f09550486483f
--- /dev/null
+++ b/node_wrappers/canny.py
@@ -0,0 +1,30 @@
+from ..utils import common_annotator_call, INPUT, define_preprocessor_inputs
+import comfy.model_management as model_management
+
+class Canny_Edge_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ low_threshold=INPUT.INT(default=100, max=255),
+ high_threshold=INPUT.INT(default=200, max=255),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Line Extractors"
+
+ def execute(self, image, low_threshold=100, high_threshold=200, resolution=512, **kwargs):
+ from custom_controlnet_aux.canny import CannyDetector
+
+ return (common_annotator_call(CannyDetector(), image, low_threshold=low_threshold, high_threshold=high_threshold, resolution=resolution), )
+
+
+
+NODE_CLASS_MAPPINGS = {
+ "CannyEdgePreprocessor": Canny_Edge_Preprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "CannyEdgePreprocessor": "Canny Edge"
+}
\ No newline at end of file
diff --git a/node_wrappers/color.py b/node_wrappers/color.py
new file mode 100644
index 0000000000000000000000000000000000000000..cadb344e88ce909f2f5b5c8af2ada9baf212c2dc
--- /dev/null
+++ b/node_wrappers/color.py
@@ -0,0 +1,26 @@
+from ..utils import common_annotator_call, INPUT, define_preprocessor_inputs
+import comfy.model_management as model_management
+
+class Color_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/T2IAdapter-only"
+
+ def execute(self, image, resolution=512, **kwargs):
+ from custom_controlnet_aux.color import ColorDetector
+
+ return (common_annotator_call(ColorDetector(), image, resolution=resolution), )
+
+
+
+NODE_CLASS_MAPPINGS = {
+ "ColorPreprocessor": Color_Preprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "ColorPreprocessor": "Color Pallete"
+}
\ No newline at end of file
diff --git a/node_wrappers/densepose.py b/node_wrappers/densepose.py
new file mode 100644
index 0000000000000000000000000000000000000000..74df04be16370a253da208e398afd75456b44e46
--- /dev/null
+++ b/node_wrappers/densepose.py
@@ -0,0 +1,31 @@
+from ..utils import common_annotator_call, INPUT, define_preprocessor_inputs
+import comfy.model_management as model_management
+
+class DensePose_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ model=INPUT.COMBO(["densepose_r50_fpn_dl.torchscript", "densepose_r101_fpn_dl.torchscript"]),
+ cmap=INPUT.COMBO(["Viridis (MagicAnimate)", "Parula (CivitAI)"]),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Faces and Poses Estimators"
+
+ def execute(self, image, model="densepose_r50_fpn_dl.torchscript", cmap="Viridis (MagicAnimate)", resolution=512):
+ from custom_controlnet_aux.densepose import DenseposeDetector
+ model = DenseposeDetector \
+ .from_pretrained(filename=model) \
+ .to(model_management.get_torch_device())
+ return (common_annotator_call(model, image, cmap="viridis" if "Viridis" in cmap else "parula", resolution=resolution), )
+
+
+NODE_CLASS_MAPPINGS = {
+ "DensePosePreprocessor": DensePose_Preprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "DensePosePreprocessor": "DensePose Estimator"
+}
\ No newline at end of file
diff --git a/node_wrappers/depth_anything.py b/node_wrappers/depth_anything.py
new file mode 100644
index 0000000000000000000000000000000000000000..97ed5b3ca09d7a8771c6b60a5bf927614863f5c4
--- /dev/null
+++ b/node_wrappers/depth_anything.py
@@ -0,0 +1,55 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
+import comfy.model_management as model_management
+
+class Depth_Anything_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ ckpt_name=INPUT.COMBO(
+ ["depth_anything_vitl14.pth", "depth_anything_vitb14.pth", "depth_anything_vits14.pth"]
+ ),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
+
+ def execute(self, image, ckpt_name="depth_anything_vitl14.pth", resolution=512, **kwargs):
+ from custom_controlnet_aux.depth_anything import DepthAnythingDetector
+
+ model = DepthAnythingDetector.from_pretrained(filename=ckpt_name).to(model_management.get_torch_device())
+ out = common_annotator_call(model, image, resolution=resolution)
+ del model
+ return (out, )
+
+class Zoe_Depth_Anything_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ environment=INPUT.COMBO(["indoor", "outdoor"]),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
+
+ def execute(self, image, environment="indoor", resolution=512, **kwargs):
+ from custom_controlnet_aux.zoe import ZoeDepthAnythingDetector
+ ckpt_name = "depth_anything_metric_depth_indoor.pt" if environment == "indoor" else "depth_anything_metric_depth_outdoor.pt"
+ model = ZoeDepthAnythingDetector.from_pretrained(filename=ckpt_name).to(model_management.get_torch_device())
+ out = common_annotator_call(model, image, resolution=resolution)
+ del model
+ return (out, )
+
+NODE_CLASS_MAPPINGS = {
+ "DepthAnythingPreprocessor": Depth_Anything_Preprocessor,
+ "Zoe_DepthAnythingPreprocessor": Zoe_Depth_Anything_Preprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "DepthAnythingPreprocessor": "Depth Anything",
+ "Zoe_DepthAnythingPreprocessor": "Zoe Depth Anything"
+}
\ No newline at end of file
diff --git a/node_wrappers/depth_anything_v2.py b/node_wrappers/depth_anything_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae697edb94edbe53739bc4fd41becb20665d87fb
--- /dev/null
+++ b/node_wrappers/depth_anything_v2.py
@@ -0,0 +1,56 @@
+from ..utils import common_annotator_call, INPUT, define_preprocessor_inputs
+import comfy.model_management as model_management
+
+class Depth_Anything_V2_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ ckpt_name=INPUT.COMBO(
+ ["depth_anything_v2_vitg.pth", "depth_anything_v2_vitl.pth", "depth_anything_v2_vitb.pth", "depth_anything_v2_vits.pth"],
+ default="depth_anything_v2_vitl.pth"
+ ),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
+
+ def execute(self, image, ckpt_name="depth_anything_v2_vitl.pth", resolution=512, **kwargs):
+ from custom_controlnet_aux.depth_anything_v2 import DepthAnythingV2Detector
+
+ model = DepthAnythingV2Detector.from_pretrained(filename=ckpt_name).to(model_management.get_torch_device())
+ out = common_annotator_call(model, image, resolution=resolution, max_depth=1)
+ del model
+ return (out, )
+
+""" class Depth_Anything_Metric_V2_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return create_node_input_types(
+ environment=(["indoor", "outdoor"], {"default": "indoor"}),
+ max_depth=("FLOAT", {"min": 0, "max": 100, "default": 20.0, "step": 0.01})
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
+
+ def execute(self, image, environment, resolution=512, max_depth=20.0, **kwargs):
+ from custom_controlnet_aux.depth_anything_v2 import DepthAnythingV2Detector
+ filename = dict(indoor="depth_anything_v2_metric_hypersim_vitl.pth", outdoor="depth_anything_v2_metric_vkitti_vitl.pth")[environment]
+ model = DepthAnythingV2Detector.from_pretrained(filename=filename).to(model_management.get_torch_device())
+ out = common_annotator_call(model, image, resolution=resolution, max_depth=max_depth)
+ del model
+ return (out, ) """
+
+NODE_CLASS_MAPPINGS = {
+ "DepthAnythingV2Preprocessor": Depth_Anything_V2_Preprocessor,
+ #"Metric_DepthAnythingV2Preprocessor": Depth_Anything_Metric_V2_Preprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "DepthAnythingV2Preprocessor": "Depth Anything V2 - Relative",
+ #"Metric_DepthAnythingV2Preprocessor": "Depth Anything V2 - Metric"
+}
\ No newline at end of file
diff --git a/node_wrappers/diffusion_edge.py b/node_wrappers/diffusion_edge.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d31c45fa865641f9780158bdc4e351d464ca92b
--- /dev/null
+++ b/node_wrappers/diffusion_edge.py
@@ -0,0 +1,41 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT, run_script
+import comfy.model_management as model_management
+import sys
+
+def install_deps():
+ try:
+ import sklearn
+ except:
+ run_script([sys.executable, '-s', '-m', 'pip', 'install', 'scikit-learn'])
+
+class DiffusionEdge_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ environment=INPUT.COMBO(["indoor", "urban", "natrual"]),
+ patch_batch_size=INPUT.INT(default=4, min=1, max=16),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Line Extractors"
+
+ def execute(self, image, environment="indoor", patch_batch_size=4, resolution=512, **kwargs):
+ install_deps()
+ from custom_controlnet_aux.diffusion_edge import DiffusionEdgeDetector
+
+ model = DiffusionEdgeDetector \
+ .from_pretrained(filename = f"diffusion_edge_{environment}.pt") \
+ .to(model_management.get_torch_device())
+ out = common_annotator_call(model, image, resolution=resolution, patch_batch_size=patch_batch_size)
+ del model
+ return (out, )
+
+NODE_CLASS_MAPPINGS = {
+ "DiffusionEdge_Preprocessor": DiffusionEdge_Preprocessor,
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "DiffusionEdge_Preprocessor": "Diffusion Edge (batch size ↑ => speed ↑, VRAM ↑)",
+}
\ No newline at end of file
diff --git a/node_wrappers/dsine.py b/node_wrappers/dsine.py
new file mode 100644
index 0000000000000000000000000000000000000000..eadb39cc7c3f867750a3f29e0b97eea0d26c0278
--- /dev/null
+++ b/node_wrappers/dsine.py
@@ -0,0 +1,31 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
+import comfy.model_management as model_management
+
+class DSINE_Normal_Map_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ fov=INPUT.FLOAT(max=365.0, default=60.0),
+ iterations=INPUT.INT(min=1, max=20, default=5),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
+
+ def execute(self, image, fov=60.0, iterations=5, resolution=512, **kwargs):
+ from custom_controlnet_aux.dsine import DsineDetector
+
+ model = DsineDetector.from_pretrained().to(model_management.get_torch_device())
+ out = common_annotator_call(model, image, fov=fov, iterations=iterations, resolution=resolution)
+ del model
+ return (out,)
+
+NODE_CLASS_MAPPINGS = {
+ "DSINE-NormalMapPreprocessor": DSINE_Normal_Map_Preprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "DSINE-NormalMapPreprocessor": "DSINE Normal Map"
+}
\ No newline at end of file
diff --git a/node_wrappers/dwpose.py b/node_wrappers/dwpose.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0422588a42157be5996d2c8acd79c863bf9efc9
--- /dev/null
+++ b/node_wrappers/dwpose.py
@@ -0,0 +1,166 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
+import comfy.model_management as model_management
+import numpy as np
+import warnings
+from ..src.custom_controlnet_aux.dwpose import DwposeDetector, AnimalposeDetector
+import os
+import json
+
+DWPOSE_MODEL_NAME = "yzd-v/DWPose"
+#Trigger startup caching for onnxruntime
+GPU_PROVIDERS = ["CUDAExecutionProvider", "DirectMLExecutionProvider", "OpenVINOExecutionProvider", "ROCMExecutionProvider", "CoreMLExecutionProvider"]
+def check_ort_gpu():
+ try:
+ import onnxruntime as ort
+ for provider in GPU_PROVIDERS:
+ if provider in ort.get_available_providers():
+ return True
+ return False
+ except:
+ return False
+
+if not os.environ.get("DWPOSE_ONNXRT_CHECKED"):
+ if check_ort_gpu():
+ print("DWPose: Onnxruntime with acceleration providers detected")
+ else:
+ warnings.warn("DWPose: Onnxruntime not found or doesn't come with acceleration providers, switch to OpenCV with CPU device. DWPose might run very slowly")
+ os.environ['AUX_ORT_PROVIDERS'] = ''
+ os.environ["DWPOSE_ONNXRT_CHECKED"] = '1'
+
+class DWPose_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ detect_hand=INPUT.COMBO(["enable", "disable"]),
+ detect_body=INPUT.COMBO(["enable", "disable"]),
+ detect_face=INPUT.COMBO(["enable", "disable"]),
+ resolution=INPUT.RESOLUTION(),
+ bbox_detector=INPUT.COMBO(
+ ["None"] + ["yolox_l.torchscript.pt", "yolox_l.onnx", "yolo_nas_l_fp16.onnx", "yolo_nas_m_fp16.onnx", "yolo_nas_s_fp16.onnx"],
+ default="yolox_l.onnx"
+ ),
+ pose_estimator=INPUT.COMBO(
+ ["dw-ll_ucoco_384_bs5.torchscript.pt", "dw-ll_ucoco_384.onnx", "dw-ll_ucoco.onnx"],
+ default="dw-ll_ucoco_384_bs5.torchscript.pt"
+ ),
+ scale_stick_for_xinsr_cn=INPUT.COMBO(["disable", "enable"])
+ )
+
+ RETURN_TYPES = ("IMAGE", "POSE_KEYPOINT")
+ FUNCTION = "estimate_pose"
+
+ CATEGORY = "ControlNet Preprocessors/Faces and Poses Estimators"
+
+ def estimate_pose(self, image, detect_hand="enable", detect_body="enable", detect_face="enable", resolution=512, bbox_detector="yolox_l.onnx", pose_estimator="dw-ll_ucoco_384.onnx", scale_stick_for_xinsr_cn="disable", **kwargs):
+ if bbox_detector == "None":
+ yolo_repo = DWPOSE_MODEL_NAME
+ elif bbox_detector == "yolox_l.onnx":
+ yolo_repo = DWPOSE_MODEL_NAME
+ elif "yolox" in bbox_detector:
+ yolo_repo = "hr16/yolox-onnx"
+ elif "yolo_nas" in bbox_detector:
+ yolo_repo = "hr16/yolo-nas-fp16"
+ else:
+ raise NotImplementedError(f"Download mechanism for {bbox_detector}")
+
+ if pose_estimator == "dw-ll_ucoco_384.onnx":
+ pose_repo = DWPOSE_MODEL_NAME
+ elif pose_estimator.endswith(".onnx"):
+ pose_repo = "hr16/UnJIT-DWPose"
+ elif pose_estimator.endswith(".torchscript.pt"):
+ pose_repo = "hr16/DWPose-TorchScript-BatchSize5"
+ else:
+ raise NotImplementedError(f"Download mechanism for {pose_estimator}")
+
+ model = DwposeDetector.from_pretrained(
+ pose_repo,
+ yolo_repo,
+ det_filename=(None if bbox_detector == "None" else bbox_detector), pose_filename=pose_estimator,
+ torchscript_device=model_management.get_torch_device()
+ )
+ detect_hand = detect_hand == "enable"
+ detect_body = detect_body == "enable"
+ detect_face = detect_face == "enable"
+ scale_stick_for_xinsr_cn = scale_stick_for_xinsr_cn == "enable"
+ self.openpose_dicts = []
+ def func(image, **kwargs):
+ pose_img, openpose_dict = model(image, **kwargs)
+ self.openpose_dicts.append(openpose_dict)
+ return pose_img
+
+ out = common_annotator_call(func, image, include_hand=detect_hand, include_face=detect_face, include_body=detect_body, image_and_json=True, resolution=resolution, xinsr_stick_scaling=scale_stick_for_xinsr_cn)
+ del model
+ return {
+ 'ui': { "openpose_json": [json.dumps(self.openpose_dicts, indent=4)] },
+ "result": (out, self.openpose_dicts)
+ }
+
+class AnimalPose_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ bbox_detector = INPUT.COMBO(
+ ["None"] + ["yolox_l.torchscript.pt", "yolox_l.onnx", "yolo_nas_l_fp16.onnx", "yolo_nas_m_fp16.onnx", "yolo_nas_s_fp16.onnx"],
+ default="yolox_l.torchscript.pt"
+ ),
+ pose_estimator = INPUT.COMBO(
+ ["rtmpose-m_ap10k_256_bs5.torchscript.pt", "rtmpose-m_ap10k_256.onnx"],
+ default="rtmpose-m_ap10k_256_bs5.torchscript.pt"
+ ),
+ resolution = INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE", "POSE_KEYPOINT")
+ FUNCTION = "estimate_pose"
+
+ CATEGORY = "ControlNet Preprocessors/Faces and Poses Estimators"
+
+ def estimate_pose(self, image, resolution=512, bbox_detector="yolox_l.onnx", pose_estimator="rtmpose-m_ap10k_256.onnx", **kwargs):
+ if bbox_detector == "None":
+ yolo_repo = DWPOSE_MODEL_NAME
+ elif bbox_detector == "yolox_l.onnx":
+ yolo_repo = DWPOSE_MODEL_NAME
+ elif "yolox" in bbox_detector:
+ yolo_repo = "hr16/yolox-onnx"
+ elif "yolo_nas" in bbox_detector:
+ yolo_repo = "hr16/yolo-nas-fp16"
+ else:
+ raise NotImplementedError(f"Download mechanism for {bbox_detector}")
+
+ if pose_estimator == "dw-ll_ucoco_384.onnx":
+ pose_repo = DWPOSE_MODEL_NAME
+ elif pose_estimator.endswith(".onnx"):
+ pose_repo = "hr16/UnJIT-DWPose"
+ elif pose_estimator.endswith(".torchscript.pt"):
+ pose_repo = "hr16/DWPose-TorchScript-BatchSize5"
+ else:
+ raise NotImplementedError(f"Download mechanism for {pose_estimator}")
+
+ model = AnimalposeDetector.from_pretrained(
+ pose_repo,
+ yolo_repo,
+ det_filename=(None if bbox_detector == "None" else bbox_detector), pose_filename=pose_estimator,
+ torchscript_device=model_management.get_torch_device()
+ )
+
+ self.openpose_dicts = []
+ def func(image, **kwargs):
+ pose_img, openpose_dict = model(image, **kwargs)
+ self.openpose_dicts.append(openpose_dict)
+ return pose_img
+
+ out = common_annotator_call(func, image, image_and_json=True, resolution=resolution)
+ del model
+ return {
+ 'ui': { "openpose_json": [json.dumps(self.openpose_dicts, indent=4)] },
+ "result": (out, self.openpose_dicts)
+ }
+
+NODE_CLASS_MAPPINGS = {
+ "DWPreprocessor": DWPose_Preprocessor,
+ "AnimalPosePreprocessor": AnimalPose_Preprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "DWPreprocessor": "DWPose Estimator",
+ "AnimalPosePreprocessor": "AnimalPose Estimator (AP10K)"
+}
diff --git a/node_wrappers/hed.py b/node_wrappers/hed.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc81246e0b99cd2f5f8922cb73d18b85184dcc6d
--- /dev/null
+++ b/node_wrappers/hed.py
@@ -0,0 +1,53 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
+import comfy.model_management as model_management
+
+class HED_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ safe=INPUT.COMBO(["enable", "disable"]),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Line Extractors"
+
+ def execute(self, image, resolution=512, **kwargs):
+ from custom_controlnet_aux.hed import HEDdetector
+
+ model = HEDdetector.from_pretrained().to(model_management.get_torch_device())
+ out = common_annotator_call(model, image, resolution=resolution, safe = kwargs["safe"] == "enable")
+ del model
+ return (out, )
+
+class Fake_Scribble_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ safe=INPUT.COMBO(["enable", "disable"]),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Line Extractors"
+
+ def execute(self, image, resolution=512, **kwargs):
+ from custom_controlnet_aux.hed import HEDdetector
+
+ model = HEDdetector.from_pretrained().to(model_management.get_torch_device())
+ out = common_annotator_call(model, image, resolution=resolution, scribble=True, safe=kwargs["safe"]=="enable")
+ del model
+ return (out, )
+
+NODE_CLASS_MAPPINGS = {
+ "HEDPreprocessor": HED_Preprocessor,
+ "FakeScribblePreprocessor": Fake_Scribble_Preprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "HEDPreprocessor": "HED Soft-Edge Lines",
+ "FakeScribblePreprocessor": "Fake Scribble Lines (aka scribble_hed)"
+}
\ No newline at end of file
diff --git a/node_wrappers/inpaint.py b/node_wrappers/inpaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..f93d87c5a1326d0c11c97305ea70e1161fd8f928
--- /dev/null
+++ b/node_wrappers/inpaint.py
@@ -0,0 +1,32 @@
+import torch
+from ..utils import INPUT
+
+class InpaintPreprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return dict(
+ required=dict(image=INPUT.IMAGE(), mask=INPUT.MASK()),
+ optional=dict(black_pixel_for_xinsir_cn=INPUT.BOOLEAN(False))
+ )
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "preprocess"
+
+ CATEGORY = "ControlNet Preprocessors/others"
+
+ def preprocess(self, image, mask, black_pixel_for_xinsir_cn=False):
+ mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(image.shape[1], image.shape[2]), mode="bilinear")
+ mask = mask.movedim(1,-1).expand((-1,-1,-1,3))
+ image = image.clone()
+ if black_pixel_for_xinsir_cn:
+ masked_pixel = 0.0
+ else:
+ masked_pixel = -1.0
+ image[mask > 0.5] = masked_pixel
+ return (image,)
+
+NODE_CLASS_MAPPINGS = {
+ "InpaintPreprocessor": InpaintPreprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "InpaintPreprocessor": "Inpaint Preprocessor"
+}
diff --git a/node_wrappers/leres.py b/node_wrappers/leres.py
new file mode 100644
index 0000000000000000000000000000000000000000..040e463d46e2f382bb3b538fa2d79eebdddaae09
--- /dev/null
+++ b/node_wrappers/leres.py
@@ -0,0 +1,32 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
+import comfy.model_management as model_management
+
+class LERES_Depth_Map_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ rm_nearest=INPUT.FLOAT(max=100.0),
+ rm_background=INPUT.FLOAT(max=100.0),
+ boost=INPUT.COMBO(["disable", "enable"]),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
+
+ def execute(self, image, rm_nearest=0, rm_background=0, resolution=512, boost="disable", **kwargs):
+ from custom_controlnet_aux.leres import LeresDetector
+
+ model = LeresDetector.from_pretrained().to(model_management.get_torch_device())
+ out = common_annotator_call(model, image, resolution=resolution, thr_a=rm_nearest, thr_b=rm_background, boost=boost == "enable")
+ del model
+ return (out, )
+
+NODE_CLASS_MAPPINGS = {
+ "LeReS-DepthMapPreprocessor": LERES_Depth_Map_Preprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "LeReS-DepthMapPreprocessor": "LeReS Depth Map (enable boost for leres++)"
+}
\ No newline at end of file
diff --git a/node_wrappers/lineart.py b/node_wrappers/lineart.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3bbcd991fda5bfdd0cb99ebda321705eb81390c
--- /dev/null
+++ b/node_wrappers/lineart.py
@@ -0,0 +1,30 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
+import comfy.model_management as model_management
+
+class LineArt_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ coarse=INPUT.COMBO((["disable", "enable"])),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Line Extractors"
+
+ def execute(self, image, resolution=512, **kwargs):
+ from custom_controlnet_aux.lineart import LineartDetector
+
+ model = LineartDetector.from_pretrained().to(model_management.get_torch_device())
+ out = common_annotator_call(model, image, resolution=resolution, coarse = kwargs["coarse"] == "enable")
+ del model
+ return (out, )
+
+NODE_CLASS_MAPPINGS = {
+ "LineArtPreprocessor": LineArt_Preprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "LineArtPreprocessor": "Realistic Lineart"
+}
\ No newline at end of file
diff --git a/node_wrappers/lineart_anime.py b/node_wrappers/lineart_anime.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e2267376d63d6e379743e19e0301ec3144dcdee
--- /dev/null
+++ b/node_wrappers/lineart_anime.py
@@ -0,0 +1,27 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
+import comfy.model_management as model_management
+
+class AnimeLineArt_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Line Extractors"
+
+ def execute(self, image, resolution=512, **kwargs):
+ from custom_controlnet_aux.lineart_anime import LineartAnimeDetector
+
+ model = LineartAnimeDetector.from_pretrained().to(model_management.get_torch_device())
+ out = common_annotator_call(model, image, resolution=resolution)
+ del model
+ return (out, )
+
+NODE_CLASS_MAPPINGS = {
+ "AnimeLineArtPreprocessor": AnimeLineArt_Preprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "AnimeLineArtPreprocessor": "Anime Lineart"
+}
\ No newline at end of file
diff --git a/node_wrappers/lineart_standard.py b/node_wrappers/lineart_standard.py
new file mode 100644
index 0000000000000000000000000000000000000000..befeb94dc4ddd126f596ca4c0db6c5f2bf4e404f
--- /dev/null
+++ b/node_wrappers/lineart_standard.py
@@ -0,0 +1,27 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
+import comfy.model_management as model_management
+
+class Lineart_Standard_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ guassian_sigma=INPUT.FLOAT(default=6.0, max=100.0),
+ intensity_threshold=INPUT.INT(default=8, max=16),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Line Extractors"
+
+ def execute(self, image, guassian_sigma=6, intensity_threshold=8, resolution=512, **kwargs):
+ from custom_controlnet_aux.lineart_standard import LineartStandardDetector
+ return (common_annotator_call(LineartStandardDetector(), image, guassian_sigma=guassian_sigma, intensity_threshold=intensity_threshold, resolution=resolution), )
+
+NODE_CLASS_MAPPINGS = {
+ "LineartStandardPreprocessor": Lineart_Standard_Preprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "LineartStandardPreprocessor": "Standard Lineart"
+}
\ No newline at end of file
diff --git a/node_wrappers/manga_line.py b/node_wrappers/manga_line.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b09f84e675a8c629e7f406509e58de048d0053b
--- /dev/null
+++ b/node_wrappers/manga_line.py
@@ -0,0 +1,27 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
+import comfy.model_management as model_management
+
+class Manga2Anime_LineArt_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Line Extractors"
+
+ def execute(self, image, resolution=512, **kwargs):
+ from custom_controlnet_aux.manga_line import LineartMangaDetector
+
+ model = LineartMangaDetector.from_pretrained().to(model_management.get_torch_device())
+ out = common_annotator_call(model, image, resolution=resolution)
+ del model
+ return (out, )
+
+NODE_CLASS_MAPPINGS = {
+ "Manga2Anime_LineArt_Preprocessor": Manga2Anime_LineArt_Preprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "Manga2Anime_LineArt_Preprocessor": "Manga Lineart (aka lineart_anime_denoise)"
+}
\ No newline at end of file
diff --git a/node_wrappers/mediapipe_face.py b/node_wrappers/mediapipe_face.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9dc8a7b5e0f0830ba7e1c883d872ef6e9673d26
--- /dev/null
+++ b/node_wrappers/mediapipe_face.py
@@ -0,0 +1,39 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT, run_script
+import comfy.model_management as model_management
+import os, sys
+import subprocess, threading
+
+def install_deps():
+ try:
+ import mediapipe
+ except ImportError:
+ run_script([sys.executable, '-s', '-m', 'pip', 'install', 'mediapipe'])
+ run_script([sys.executable, '-s', '-m', 'pip', 'install', '--upgrade', 'protobuf'])
+
+class Media_Pipe_Face_Mesh_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ max_faces=INPUT.INT(default=10, min=1, max=50), #Which image has more than 50 detectable faces?
+ min_confidence=INPUT.FLOAT(default=0.5, min=0.1),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "detect"
+
+ CATEGORY = "ControlNet Preprocessors/Faces and Poses Estimators"
+
+ def detect(self, image, max_faces=10, min_confidence=0.5, resolution=512):
+ #Ref: https://github.com/Fannovel16/comfy_controlnet_preprocessors/issues/70#issuecomment-1677967369
+ install_deps()
+ from custom_controlnet_aux.mediapipe_face import MediapipeFaceDetector
+ return (common_annotator_call(MediapipeFaceDetector(), image, max_faces=max_faces, min_confidence=min_confidence, resolution=resolution), )
+
+NODE_CLASS_MAPPINGS = {
+ "MediaPipe-FaceMeshPreprocessor": Media_Pipe_Face_Mesh_Preprocessor
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "MediaPipe-FaceMeshPreprocessor": "MediaPipe Face Mesh"
+}
\ No newline at end of file
diff --git a/node_wrappers/mesh_graphormer.py b/node_wrappers/mesh_graphormer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cccc35ac012b52c0f7045a61be3b891dc569a99
--- /dev/null
+++ b/node_wrappers/mesh_graphormer.py
@@ -0,0 +1,158 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT, MAX_RESOLUTION, run_script
+import comfy.model_management as model_management
+import numpy as np
+import torch
+from einops import rearrange
+import os, sys
+import subprocess, threading
+import scipy.ndimage
+import cv2
+import torch.nn.functional as F
+
+def install_deps():
+ try:
+ import mediapipe
+ except ImportError:
+ run_script([sys.executable, '-s', '-m', 'pip', 'install', 'mediapipe'])
+ run_script([sys.executable, '-s', '-m', 'pip', 'install', '--upgrade', 'protobuf'])
+
+ try:
+ import trimesh
+ except ImportError:
+ run_script([sys.executable, '-s', '-m', 'pip', 'install', 'trimesh[easy]'])
+
+#Sauce: https://github.com/comfyanonymous/ComfyUI/blob/8c6493578b3dda233e9b9a953feeaf1e6ca434ad/comfy_extras/nodes_mask.py#L309
+def expand_mask(mask, expand, tapered_corners):
+ c = 0 if tapered_corners else 1
+ kernel = np.array([[c, 1, c],
+ [1, 1, 1],
+ [c, 1, c]])
+ mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
+ out = []
+ for m in mask:
+ output = m.numpy()
+ for _ in range(abs(expand)):
+ if expand < 0:
+ output = scipy.ndimage.grey_erosion(output, footprint=kernel)
+ else:
+ output = scipy.ndimage.grey_dilation(output, footprint=kernel)
+ output = torch.from_numpy(output)
+ out.append(output)
+ return torch.stack(out, dim=0)
+
+class Mesh_Graphormer_Depth_Map_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ mask_bbox_padding=("INT", {"default": 30, "min": 0, "max": 100}),
+ resolution=INPUT.RESOLUTION(),
+ mask_type=INPUT.COMBO(["based_on_depth", "tight_bboxes", "original"]),
+ mask_expand=INPUT.INT(default=5, min=-MAX_RESOLUTION, max=MAX_RESOLUTION),
+ rand_seed=INPUT.INT(default=88, min=0, max=0xffffffffffffffff),
+ detect_thr=INPUT.FLOAT(default=0.6, min=0.1),
+ presence_thr=INPUT.FLOAT(default=0.6, min=0.1)
+ )
+
+ RETURN_TYPES = ("IMAGE", "MASK")
+ RETURN_NAMES = ("IMAGE", "INPAINTING_MASK")
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
+
+ def execute(self, image, mask_bbox_padding=30, mask_type="based_on_depth", mask_expand=5, resolution=512, rand_seed=88, detect_thr=0.6, presence_thr=0.6, **kwargs):
+ install_deps()
+ from custom_controlnet_aux.mesh_graphormer import MeshGraphormerDetector
+ model = kwargs["model"] if "model" in kwargs \
+ else MeshGraphormerDetector.from_pretrained(detect_thr=detect_thr, presence_thr=presence_thr).to(model_management.get_torch_device())
+
+ depth_map_list = []
+ mask_list = []
+ for single_image in image:
+ np_image = np.asarray(single_image.cpu() * 255., dtype=np.uint8)
+ depth_map, mask, info = model(np_image, output_type="np", detect_resolution=resolution, mask_bbox_padding=mask_bbox_padding, seed=rand_seed)
+ if mask_type == "based_on_depth":
+ H, W = mask.shape[:2]
+ mask = cv2.resize(depth_map.copy(), (W, H))
+ mask[mask > 0] = 255
+
+ elif mask_type == "tight_bboxes":
+ mask = np.zeros_like(mask)
+ hand_bboxes = (info or {}).get("abs_boxes") or []
+ for hand_bbox in hand_bboxes:
+ x_min, x_max, y_min, y_max = hand_bbox
+ mask[y_min:y_max+1, x_min:x_max+1, :] = 255 #HWC
+
+ mask = mask[:, :, :1]
+ depth_map_list.append(torch.from_numpy(depth_map.astype(np.float32) / 255.0))
+ mask_list.append(torch.from_numpy(mask.astype(np.float32) / 255.0))
+ depth_maps, masks = torch.stack(depth_map_list, dim=0), rearrange(torch.stack(mask_list, dim=0), "n h w 1 -> n 1 h w")
+ return depth_maps, expand_mask(masks, mask_expand, tapered_corners=True)
+
+def normalize_size_base_64(w, h):
+ short_side = min(w, h)
+ remainder = short_side % 64
+ return short_side - remainder + (64 if remainder > 0 else 0)
+
+class Mesh_Graphormer_With_ImpactDetector_Depth_Map_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ types = define_preprocessor_inputs(
+ # Impact pack
+ bbox_threshold=INPUT.FLOAT(default=0.5, min=0.1),
+ bbox_dilation=INPUT.INT(default=10, min=-512, max=512),
+ bbox_crop_factor=INPUT.FLOAT(default=3.0, min=1.0, max=10.0),
+ drop_size=INPUT.INT(default=10, min=1, max=MAX_RESOLUTION),
+ # Mesh Graphormer
+ mask_bbox_padding=INPUT.INT(default=30, min=0, max=100),
+ mask_type=INPUT.COMBO(["based_on_depth", "tight_bboxes", "original"]),
+ mask_expand=INPUT.INT(default=5, min=-MAX_RESOLUTION, max=MAX_RESOLUTION),
+ rand_seed=INPUT.INT(default=88, min=0, max=0xffffffffffffffff),
+ resolution=INPUT.RESOLUTION()
+ )
+ types["required"]["bbox_detector"] = ("BBOX_DETECTOR", )
+ return types
+
+ RETURN_TYPES = ("IMAGE", "MASK")
+ RETURN_NAMES = ("IMAGE", "INPAINTING_MASK")
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
+
+ def execute(self, image, bbox_detector, bbox_threshold=0.5, bbox_dilation=10, bbox_crop_factor=3.0, drop_size=10, resolution=512, **mesh_graphormer_kwargs):
+ install_deps()
+ from custom_controlnet_aux.mesh_graphormer import MeshGraphormerDetector
+ mesh_graphormer_node = Mesh_Graphormer_Depth_Map_Preprocessor()
+ model = MeshGraphormerDetector.from_pretrained(detect_thr=0.6, presence_thr=0.6).to(model_management.get_torch_device())
+ mesh_graphormer_kwargs["model"] = model
+
+ frames = image
+ depth_maps, masks = [], []
+ for idx in range(len(frames)):
+ frame = frames[idx:idx+1,...] #Impact Pack's BBOX_DETECTOR only supports single batch image
+ bbox_detector.setAux('face') # make default prompt as 'face' if empty prompt for CLIPSeg
+ _, segs = bbox_detector.detect(frame, bbox_threshold, bbox_dilation, bbox_crop_factor, drop_size)
+ bbox_detector.setAux(None)
+
+ n, h, w, _ = frame.shape
+ depth_map, mask = torch.zeros_like(frame), torch.zeros(n, 1, h, w)
+ for i, seg in enumerate(segs):
+ x1, y1, x2, y2 = seg.crop_region
+ cropped_image = frame[:, y1:y2, x1:x2, :] # Never use seg.cropped_image to handle overlapping area
+ mesh_graphormer_kwargs["resolution"] = 0 #Disable resizing
+ sub_depth_map, sub_mask = mesh_graphormer_node.execute(cropped_image, **mesh_graphormer_kwargs)
+ depth_map[:, y1:y2, x1:x2, :] = sub_depth_map
+ mask[:, :, y1:y2, x1:x2] = sub_mask
+
+ depth_maps.append(depth_map)
+ masks.append(mask)
+
+ return (torch.cat(depth_maps), torch.cat(masks))
+
+NODE_CLASS_MAPPINGS = {
+ "MeshGraphormer-DepthMapPreprocessor": Mesh_Graphormer_Depth_Map_Preprocessor,
+ "MeshGraphormer+ImpactDetector-DepthMapPreprocessor": Mesh_Graphormer_With_ImpactDetector_Depth_Map_Preprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "MeshGraphormer-DepthMapPreprocessor": "MeshGraphormer Hand Refiner",
+ "MeshGraphormer+ImpactDetector-DepthMapPreprocessor": "MeshGraphormer Hand Refiner With External Detector"
+}
\ No newline at end of file
diff --git a/node_wrappers/metric3d.py b/node_wrappers/metric3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..1be81aea73cabe185c98649bfec6728fc24a625d
--- /dev/null
+++ b/node_wrappers/metric3d.py
@@ -0,0 +1,62 @@
+import os
+# Disable NPU device initialization and problematic MMCV ops to prevent RuntimeError
+os.environ['NPU_DEVICE_COUNT'] = '0'
+os.environ['MMCV_WITH_OPS'] = '0'
+
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT, MAX_RESOLUTION
+import comfy.model_management as model_management
+
+class Metric3D_Depth_Map_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ backbone=INPUT.COMBO(["vit-small", "vit-large", "vit-giant2"]),
+ fx=INPUT.INT(default=1000, min=1, max=MAX_RESOLUTION),
+ fy=INPUT.INT(default=1000, min=1, max=MAX_RESOLUTION),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
+
+ def execute(self, image, backbone="vit-small", fx=1000, fy=1000, resolution=512):
+ from custom_controlnet_aux.metric3d import Metric3DDetector
+ model = Metric3DDetector.from_pretrained(filename=f"metric_depth_{backbone.replace('-', '_')}_800k.pth").to(model_management.get_torch_device())
+ cb = lambda image, **kwargs: model(image, **kwargs)[0]
+ out = common_annotator_call(cb, image, resolution=resolution, fx=fx, fy=fy, depth_and_normal=True)
+ del model
+ return (out, )
+
+class Metric3D_Normal_Map_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ backbone=INPUT.COMBO(["vit-small", "vit-large", "vit-giant2"]),
+ fx=INPUT.INT(default=1000, min=1, max=MAX_RESOLUTION),
+ fy=INPUT.INT(default=1000, min=1, max=MAX_RESOLUTION),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
+
+ def execute(self, image, backbone="vit-small", fx=1000, fy=1000, resolution=512):
+ from custom_controlnet_aux.metric3d import Metric3DDetector
+ model = Metric3DDetector.from_pretrained(filename=f"metric_depth_{backbone.replace('-', '_')}_800k.pth").to(model_management.get_torch_device())
+ cb = lambda image, **kwargs: model(image, **kwargs)[1]
+ out = common_annotator_call(cb, image, resolution=resolution, fx=fx, fy=fy, depth_and_normal=True)
+ del model
+ return (out, )
+
+NODE_CLASS_MAPPINGS = {
+ "Metric3D-DepthMapPreprocessor": Metric3D_Depth_Map_Preprocessor,
+ "Metric3D-NormalMapPreprocessor": Metric3D_Normal_Map_Preprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "Metric3D-DepthMapPreprocessor": "Metric3D Depth Map",
+ "Metric3D-NormalMapPreprocessor": "Metric3D Normal Map"
+}
diff --git a/node_wrappers/midas.py b/node_wrappers/midas.py
new file mode 100644
index 0000000000000000000000000000000000000000..54c06853867859bd36b282da47b0f057bf499721
--- /dev/null
+++ b/node_wrappers/midas.py
@@ -0,0 +1,59 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
+import comfy.model_management as model_management
+import numpy as np
+
+class MIDAS_Normal_Map_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ a=INPUT.FLOAT(default=np.pi * 2.0, min=0.0, max=np.pi * 5.0),
+ bg_threshold=INPUT.FLOAT(default=0.1),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
+
+ def execute(self, image, a=np.pi * 2.0, bg_threshold=0.1, resolution=512, **kwargs):
+ from custom_controlnet_aux.midas import MidasDetector
+
+ model = MidasDetector.from_pretrained().to(model_management.get_torch_device())
+ #Dirty hack :))
+ cb = lambda image, **kargs: model(image, **kargs)[1]
+ out = common_annotator_call(cb, image, resolution=resolution, a=a, bg_th=bg_threshold, depth_and_normal=True)
+ del model
+ return (out, )
+
+class MIDAS_Depth_Map_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ a=INPUT.FLOAT(default=np.pi * 2.0, min=0.0, max=np.pi * 5.0),
+ bg_threshold=INPUT.FLOAT(default=0.1),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
+
+ def execute(self, image, a=np.pi * 2.0, bg_threshold=0.1, resolution=512, **kwargs):
+ from custom_controlnet_aux.midas import MidasDetector
+
+ # Ref: https://github.com/lllyasviel/ControlNet/blob/main/gradio_depth2image.py
+ model = MidasDetector.from_pretrained().to(model_management.get_torch_device())
+ out = common_annotator_call(model, image, resolution=resolution, a=a, bg_th=bg_threshold)
+ del model
+ return (out, )
+
+NODE_CLASS_MAPPINGS = {
+ "MiDaS-NormalMapPreprocessor": MIDAS_Normal_Map_Preprocessor,
+ "MiDaS-DepthMapPreprocessor": MIDAS_Depth_Map_Preprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "MiDaS-NormalMapPreprocessor": "MiDaS Normal Map",
+ "MiDaS-DepthMapPreprocessor": "MiDaS Depth Map"
+}
\ No newline at end of file
diff --git a/node_wrappers/mlsd.py b/node_wrappers/mlsd.py
new file mode 100644
index 0000000000000000000000000000000000000000..30bae1c61f04ac10b401bbc7ffeb1d2fdb94c60d
--- /dev/null
+++ b/node_wrappers/mlsd.py
@@ -0,0 +1,31 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
+import comfy.model_management as model_management
+import numpy as np
+
+class MLSD_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ score_threshold=INPUT.FLOAT(default=0.1, min=0.01, max=2.0),
+ dist_threshold=INPUT.FLOAT(default=0.1, min=0.01, max=20.0),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Line Extractors"
+
+ def execute(self, image, score_threshold, dist_threshold, resolution=512, **kwargs):
+ from custom_controlnet_aux.mlsd import MLSDdetector
+
+ model = MLSDdetector.from_pretrained().to(model_management.get_torch_device())
+ out = common_annotator_call(model, image, resolution=resolution, thr_v=score_threshold, thr_d=dist_threshold)
+ return (out, )
+
+NODE_CLASS_MAPPINGS = {
+ "M-LSDPreprocessor": MLSD_Preprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "M-LSDPreprocessor": "M-LSD Lines"
+}
\ No newline at end of file
diff --git a/node_wrappers/normalbae.py b/node_wrappers/normalbae.py
new file mode 100644
index 0000000000000000000000000000000000000000..af013e15373a146422715daf77537a9847042e89
--- /dev/null
+++ b/node_wrappers/normalbae.py
@@ -0,0 +1,27 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
+import comfy.model_management as model_management
+
+class BAE_Normal_Map_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
+
+ def execute(self, image, resolution=512, **kwargs):
+ from custom_controlnet_aux.normalbae import NormalBaeDetector
+
+ model = NormalBaeDetector.from_pretrained().to(model_management.get_torch_device())
+ out = common_annotator_call(model, image, resolution=resolution)
+ del model
+ return (out,)
+
+NODE_CLASS_MAPPINGS = {
+ "BAE-NormalMapPreprocessor": BAE_Normal_Map_Preprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "BAE-NormalMapPreprocessor": "BAE Normal Map"
+}
\ No newline at end of file
diff --git a/node_wrappers/oneformer.py b/node_wrappers/oneformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c7cd3e7a5aa5d820936a30045c1a4518dff5cdc
--- /dev/null
+++ b/node_wrappers/oneformer.py
@@ -0,0 +1,50 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
+import comfy.model_management as model_management
+
+class OneFormer_COCO_SemSegPreprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "semantic_segmentate"
+
+ CATEGORY = "ControlNet Preprocessors/Semantic Segmentation"
+
+ def semantic_segmentate(self, image, resolution=512):
+ from custom_controlnet_aux.oneformer import OneformerSegmentor
+
+ model = OneformerSegmentor.from_pretrained(filename="150_16_swin_l_oneformer_coco_100ep.pth")
+ model = model.to(model_management.get_torch_device())
+ out = common_annotator_call(model, image, resolution=resolution)
+ del model
+ return (out,)
+
+class OneFormer_ADE20K_SemSegPreprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "semantic_segmentate"
+
+ CATEGORY = "ControlNet Preprocessors/Semantic Segmentation"
+
+ def semantic_segmentate(self, image, resolution=512):
+ from custom_controlnet_aux.oneformer import OneformerSegmentor
+
+ model = OneformerSegmentor.from_pretrained(filename="250_16_swin_l_oneformer_ade20k_160k.pth")
+ model = model.to(model_management.get_torch_device())
+ out = common_annotator_call(model, image, resolution=resolution)
+ del model
+ return (out,)
+
+NODE_CLASS_MAPPINGS = {
+ "OneFormer-COCO-SemSegPreprocessor": OneFormer_COCO_SemSegPreprocessor,
+ "OneFormer-ADE20K-SemSegPreprocessor": OneFormer_ADE20K_SemSegPreprocessor
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "OneFormer-COCO-SemSegPreprocessor": "OneFormer COCO Segmentor",
+ "OneFormer-ADE20K-SemSegPreprocessor": "OneFormer ADE20K Segmentor"
+}
\ No newline at end of file
diff --git a/node_wrappers/openpose.py b/node_wrappers/openpose.py
new file mode 100644
index 0000000000000000000000000000000000000000..c579e6cdde1add9d7d2f0ccab72bd67bafc35eea
--- /dev/null
+++ b/node_wrappers/openpose.py
@@ -0,0 +1,48 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
+import comfy.model_management as model_management
+import json
+
+class OpenPose_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ detect_hand=INPUT.COMBO(["enable", "disable"]),
+ detect_body=INPUT.COMBO(["enable", "disable"]),
+ detect_face=INPUT.COMBO(["enable", "disable"]),
+ resolution=INPUT.RESOLUTION(),
+ scale_stick_for_xinsr_cn=INPUT.COMBO(["disable", "enable"])
+ )
+
+ RETURN_TYPES = ("IMAGE", "POSE_KEYPOINT")
+ FUNCTION = "estimate_pose"
+
+ CATEGORY = "ControlNet Preprocessors/Faces and Poses Estimators"
+
+ def estimate_pose(self, image, detect_hand="enable", detect_body="enable", detect_face="enable", scale_stick_for_xinsr_cn="disable", resolution=512, **kwargs):
+ from custom_controlnet_aux.open_pose import OpenposeDetector
+
+ detect_hand = detect_hand == "enable"
+ detect_body = detect_body == "enable"
+ detect_face = detect_face == "enable"
+ scale_stick_for_xinsr_cn = scale_stick_for_xinsr_cn == "enable"
+
+ model = OpenposeDetector.from_pretrained().to(model_management.get_torch_device())
+ self.openpose_dicts = []
+ def func(image, **kwargs):
+ pose_img, openpose_dict = model(image, **kwargs)
+ self.openpose_dicts.append(openpose_dict)
+ return pose_img
+
+ out = common_annotator_call(func, image, include_hand=detect_hand, include_face=detect_face, include_body=detect_body, image_and_json=True, xinsr_stick_scaling=scale_stick_for_xinsr_cn, resolution=resolution)
+ del model
+ return {
+ 'ui': { "openpose_json": [json.dumps(self.openpose_dicts, indent=4)] },
+ "result": (out, self.openpose_dicts)
+ }
+
+NODE_CLASS_MAPPINGS = {
+ "OpenposePreprocessor": OpenPose_Preprocessor,
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "OpenposePreprocessor": "OpenPose Pose",
+}
\ No newline at end of file
diff --git a/node_wrappers/pidinet.py b/node_wrappers/pidinet.py
new file mode 100644
index 0000000000000000000000000000000000000000..92f22f92c6d8313cbe2fbe88c8340b4d68742ad3
--- /dev/null
+++ b/node_wrappers/pidinet.py
@@ -0,0 +1,30 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
+import comfy.model_management as model_management
+
+class PIDINET_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ safe=INPUT.COMBO(["enable", "disable"]),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Line Extractors"
+
+ def execute(self, image, safe, resolution=512, **kwargs):
+ from custom_controlnet_aux.pidi import PidiNetDetector
+
+ model = PidiNetDetector.from_pretrained().to(model_management.get_torch_device())
+ out = common_annotator_call(model, image, resolution=resolution, safe = safe == "enable")
+ del model
+ return (out, )
+
+NODE_CLASS_MAPPINGS = {
+ "PiDiNetPreprocessor": PIDINET_Preprocessor,
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "PiDiNetPreprocessor": "PiDiNet Soft-Edge Lines"
+}
\ No newline at end of file
diff --git a/node_wrappers/pose_keypoint_postprocess.py b/node_wrappers/pose_keypoint_postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..11f1e82d9010782a476ecc60b2050906b3ae7b3f
--- /dev/null
+++ b/node_wrappers/pose_keypoint_postprocess.py
@@ -0,0 +1,340 @@
+import folder_paths
+import json
+import os
+import numpy as np
+import cv2
+from PIL import ImageColor
+from einops import rearrange
+import torch
+import itertools
+
+from ..src.custom_controlnet_aux.dwpose import draw_poses, draw_animalposes, decode_json_as_poses
+
+
+"""
+Format of POSE_KEYPOINT (AP10K keypoints):
+[{
+ "version": "ap10k",
+ "animals": [
+ [[x1, y1, 1], [x2, y2, 1],..., [x17, y17, 1]],
+ [[x1, y1, 1], [x2, y2, 1],..., [x17, y17, 1]],
+ ...
+ ],
+ "canvas_height": 512,
+ "canvas_width": 768
+},...]
+Format of POSE_KEYPOINT (OpenPose keypoints):
+[{
+ "people": [
+ {
+ 'pose_keypoints_2d': [[x1, y1, 1], [x2, y2, 1],..., [x17, y17, 1]]
+ "face_keypoints_2d": [[x1, y1, 1], [x2, y2, 1],..., [x68, y68, 1]],
+ "hand_left_keypoints_2d": [[x1, y1, 1], [x2, y2, 1],..., [x21, y21, 1]],
+ "hand_right_keypoints_2d":[[x1, y1, 1], [x2, y2, 1],..., [x21, y21, 1]],
+ }
+ ],
+ "canvas_height": canvas_height,
+ "canvas_width": canvas_width,
+},...]
+"""
+
+class SavePoseKpsAsJsonFile:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "pose_kps": ("POSE_KEYPOINT",),
+ "filename_prefix": ("STRING", {"default": "PoseKeypoint"})
+ }
+ }
+ RETURN_TYPES = ()
+ FUNCTION = "save_pose_kps"
+ OUTPUT_NODE = True
+ CATEGORY = "ControlNet Preprocessors/Pose Keypoint Postprocess"
+ def __init__(self):
+ self.output_dir = folder_paths.get_output_directory()
+ self.type = "output"
+ self.prefix_append = ""
+ def save_pose_kps(self, pose_kps, filename_prefix):
+ filename_prefix += self.prefix_append
+ full_output_folder, filename, counter, subfolder, filename_prefix = \
+ folder_paths.get_save_image_path(filename_prefix, self.output_dir, pose_kps[0]["canvas_width"], pose_kps[0]["canvas_height"])
+ file = f"{filename}_{counter:05}.json"
+ with open(os.path.join(full_output_folder, file), 'w') as f:
+ json.dump(pose_kps , f)
+ return {}
+
+#COCO-Wholebody doesn't have eyebrows as it inherits 68 keypoints format
+#Perhaps eyebrows can be estimated tho
+FACIAL_PARTS = ["skin", "left_eye", "right_eye", "nose", "upper_lip", "inner_mouth", "lower_lip"]
+LAPA_COLORS = dict(
+ skin="rgb(0, 153, 255)",
+ left_eye="rgb(0, 204, 153)",
+ right_eye="rgb(255, 153, 0)",
+ nose="rgb(255, 102, 255)",
+ upper_lip="rgb(102, 0, 51)",
+ inner_mouth="rgb(255, 204, 255)",
+ lower_lip="rgb(255, 0, 102)"
+)
+
+#One-based index
+def kps_idxs(start, end):
+ step = -1 if start > end else 1
+ return list(range(start-1, end+1-1, step))
+
+#Source: https://www.researchgate.net/profile/Fabrizio-Falchi/publication/338048224/figure/fig1/AS:837860722741255@1576772971540/68-facial-landmarks.jpg
+FACIAL_PART_RANGES = dict(
+ skin=kps_idxs(1, 17) + kps_idxs(27, 18),
+ nose=kps_idxs(28, 36),
+ left_eye=kps_idxs(37, 42),
+ right_eye=kps_idxs(43, 48),
+ upper_lip=kps_idxs(49, 55) + kps_idxs(65, 61),
+ lower_lip=kps_idxs(61, 68),
+ inner_mouth=kps_idxs(61, 65) + kps_idxs(55, 49)
+)
+
+def is_normalized(keypoints) -> bool:
+ point_normalized = [
+ 0 <= np.abs(k[0]) <= 1 and 0 <= np.abs(k[1]) <= 1
+ for k in keypoints
+ if k is not None
+ ]
+ if not point_normalized:
+ return False
+ return np.all(point_normalized)
+
+class FacialPartColoringFromPoseKps:
+ @classmethod
+ def INPUT_TYPES(s):
+ input_types = {
+ "required": {"pose_kps": ("POSE_KEYPOINT",), "mode": (["point", "polygon"], {"default": "polygon"})}
+ }
+ for facial_part in FACIAL_PARTS:
+ input_types["required"][facial_part] = ("STRING", {"default": LAPA_COLORS[facial_part], "multiline": False})
+ return input_types
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "colorize"
+ CATEGORY = "ControlNet Preprocessors/Pose Keypoint Postprocess"
+ def colorize(self, pose_kps, mode, **facial_part_colors):
+ pose_frames = pose_kps
+ np_frames = [self.draw_kps(pose_frame, mode, **facial_part_colors) for pose_frame in pose_frames]
+ np_frames = np.stack(np_frames, axis=0)
+ return (torch.from_numpy(np_frames).float() / 255.,)
+
+ def draw_kps(self, pose_frame, mode, **facial_part_colors):
+ width, height = pose_frame["canvas_width"], pose_frame["canvas_height"]
+ canvas = np.zeros((height, width, 3), dtype=np.uint8)
+ for person, part_name in itertools.product(pose_frame["people"], FACIAL_PARTS):
+ n = len(person["face_keypoints_2d"]) // 3
+ facial_kps = rearrange(np.array(person["face_keypoints_2d"]), "(n c) -> n c", n=n, c=3)[:, :2]
+ if is_normalized(facial_kps):
+ facial_kps *= (width, height)
+ facial_kps = facial_kps.astype(np.int32)
+ part_color = ImageColor.getrgb(facial_part_colors[part_name])[:3]
+ part_contours = facial_kps[FACIAL_PART_RANGES[part_name], :]
+ if mode == "point":
+ for pt in part_contours:
+ cv2.circle(canvas, pt, radius=2, color=part_color, thickness=-1)
+ else:
+ cv2.fillPoly(canvas, pts=[part_contours], color=part_color)
+ return canvas
+
+# https://raw.githubusercontent.com/CMU-Perceptual-Computing-Lab/openpose/master/.github/media/keypoints_pose_18.png
+BODY_PART_INDEXES = {
+ "Head": (16, 14, 0, 15, 17),
+ "Neck": (0, 1),
+ "Shoulder": (2, 5),
+ "Torso": (2, 5, 8, 11),
+ "RArm": (2, 3),
+ "RForearm": (3, 4),
+ "LArm": (5, 6),
+ "LForearm": (6, 7),
+ "RThigh": (8, 9),
+ "RLeg": (9, 10),
+ "LThigh": (11, 12),
+ "LLeg": (12, 13)
+}
+BODY_PART_DEFAULT_W_H = {
+ "Head": "256, 256",
+ "Neck": "100, 100",
+ "Shoulder": '',
+ "Torso": "350, 450",
+ "RArm": "128, 256",
+ "RForearm": "128, 256",
+ "LArm": "128, 256",
+ "LForearm": "128, 256",
+ "RThigh": "128, 256",
+ "RLeg": "128, 256",
+ "LThigh": "128, 256",
+ "LLeg": "128, 256"
+}
+
+class SinglePersonProcess:
+ @classmethod
+ def sort_and_get_max_people(s, pose_kps):
+ for idx in range(len(pose_kps)):
+ pose_kps[idx]["people"] = sorted(pose_kps[idx]["people"], key=lambda person:person["pose_keypoints_2d"][0])
+ return pose_kps, max(len(frame["people"]) for frame in pose_kps)
+
+ def __init__(self, pose_kps, person_idx=0) -> None:
+ self.width, self.height = pose_kps[0]["canvas_width"], pose_kps[0]["canvas_height"]
+ self.poses = [
+ self.normalize(pose_frame["people"][person_idx]["pose_keypoints_2d"])
+ if person_idx < len(pose_frame["people"])
+ else None
+ for pose_frame in pose_kps
+ ]
+
+ def normalize(self, pose_kps_2d):
+ n = len(pose_kps_2d) // 3
+ pose_kps_2d = rearrange(np.array(pose_kps_2d), "(n c) -> n c", n=n, c=3)
+ pose_kps_2d[np.argwhere(pose_kps_2d[:,2]==0), :] = np.iinfo(np.int32).max // 2 #Safe large value
+ pose_kps_2d = pose_kps_2d[:, :2]
+ if is_normalized(pose_kps_2d):
+ pose_kps_2d *= (self.width, self.height)
+ return pose_kps_2d
+
+ def get_xyxy_bboxes(self, part_name, bbox_size=(128, 256)):
+ width, height = bbox_size
+ xyxy_bboxes = {}
+ for idx, pose in enumerate(self.poses):
+ if pose is None:
+ xyxy_bboxes[idx] = (np.iinfo(np.int32).max // 2,) * 4
+ continue
+ pts = pose[BODY_PART_INDEXES[part_name], :]
+
+ #top_left = np.min(pts[:,0]), np.min(pts[:,1])
+ #bottom_right = np.max(pts[:,0]), np.max(pts[:,1])
+ #pad_width = np.maximum(width - (bottom_right[0]-top_left[0]), 0) / 2
+ #pad_height = np.maximum(height - (bottom_right[1]-top_left[1]), 0) / 2
+ #xyxy_bboxes.append((
+ # top_left[0] - pad_width, top_left[1] - pad_height,
+ # bottom_right[0] + pad_width, bottom_right[1] + pad_height,
+ #))
+
+ x_mid, y_mid = np.mean(pts[:, 0]), np.mean(pts[:, 1])
+ xyxy_bboxes[idx] = (
+ x_mid - width/2, y_mid - height/2,
+ x_mid + width/2, y_mid + height/2
+ )
+ return xyxy_bboxes
+
+class UpperBodyTrackingFromPoseKps:
+ PART_NAMES = ["Head", "Neck", "Shoulder", "Torso", "RArm", "RForearm", "LArm", "LForearm"]
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "pose_kps": ("POSE_KEYPOINT",),
+ "id_include": ("STRING", {"default": '', "multiline": False}),
+ **{part_name + "_width_height": ("STRING", {"default": BODY_PART_DEFAULT_W_H[part_name], "multiline": False}) for part_name in s.PART_NAMES}
+ }
+ }
+
+ RETURN_TYPES = ("TRACKING", "STRING")
+ RETURN_NAMES = ("tracking", "prompt")
+ FUNCTION = "convert"
+ CATEGORY = "ControlNet Preprocessors/Pose Keypoint Postprocess"
+
+ def convert(self, pose_kps, id_include, **parts_width_height):
+ parts_width_height = {part_name.replace("_width_height", ''): value for part_name, value in parts_width_height.items()}
+ enabled_part_names = [part_name for part_name in self.PART_NAMES if len(parts_width_height[part_name].strip())]
+ tracked = {part_name: {} for part_name in enabled_part_names}
+ id_include = id_include.strip()
+ id_include = list(map(int, id_include.split(','))) if len(id_include) else []
+ prompt_string = ''
+ pose_kps, max_people = SinglePersonProcess.sort_and_get_max_people(pose_kps)
+
+ for person_idx in range(max_people):
+ if len(id_include) and person_idx not in id_include:
+ continue
+ processor = SinglePersonProcess(pose_kps, person_idx)
+ for part_name in enabled_part_names:
+ bbox_size = tuple(map(int, parts_width_height[part_name].split(',')))
+ part_bboxes = processor.get_xyxy_bboxes(part_name, bbox_size)
+ id_coordinates = {idx: part_bbox+(processor.width, processor.height) for idx, part_bbox in part_bboxes.items()}
+ tracked[part_name][person_idx] = id_coordinates
+
+ for class_name, class_data in tracked.items():
+ for class_id in class_data.keys():
+ class_id_str = str(class_id)
+ # Use the incoming prompt for each class name and ID
+ _class_name = class_name.replace('L', '').replace('R', '').lower()
+ prompt_string += f'"{class_id_str}.{class_name}": "({_class_name})",\n'
+
+ return (tracked, prompt_string)
+
+
+def numpy2torch(np_image: np.ndarray) -> torch.Tensor:
+ """ [H, W, C] => [B=1, H, W, C]"""
+ return torch.from_numpy(np_image.astype(np.float32) / 255).unsqueeze(0)
+
+
+class RenderPeopleKps:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "kps": ("POSE_KEYPOINT",),
+ "render_body": ("BOOLEAN", {"default": True}),
+ "render_hand": ("BOOLEAN", {"default": True}),
+ "render_face": ("BOOLEAN", {"default": True}),
+ }
+ }
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "render"
+ CATEGORY = "ControlNet Preprocessors/Pose Keypoint Postprocess"
+
+ def render(self, kps, render_body, render_hand, render_face) -> tuple[np.ndarray]:
+ if isinstance(kps, list):
+ kps = kps[0]
+
+ poses, _, height, width = decode_json_as_poses(kps)
+ np_image = draw_poses(
+ poses,
+ height,
+ width,
+ render_body,
+ render_hand,
+ render_face,
+ )
+ return (numpy2torch(np_image),)
+
+class RenderAnimalKps:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "kps": ("POSE_KEYPOINT",),
+ }
+ }
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "render"
+ CATEGORY = "ControlNet Preprocessors/Pose Keypoint Postprocess"
+
+ def render(self, kps) -> tuple[np.ndarray]:
+ if isinstance(kps, list):
+ kps = kps[0]
+
+ _, poses, height, width = decode_json_as_poses(kps)
+ np_image = draw_animalposes(poses, height, width)
+ return (numpy2torch(np_image),)
+
+
+NODE_CLASS_MAPPINGS = {
+ "SavePoseKpsAsJsonFile": SavePoseKpsAsJsonFile,
+ "FacialPartColoringFromPoseKps": FacialPartColoringFromPoseKps,
+ "UpperBodyTrackingFromPoseKps": UpperBodyTrackingFromPoseKps,
+ "RenderPeopleKps": RenderPeopleKps,
+ "RenderAnimalKps": RenderAnimalKps,
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "SavePoseKpsAsJsonFile": "Save Pose Keypoints",
+ "FacialPartColoringFromPoseKps": "Colorize Facial Parts from PoseKPS",
+ "UpperBodyTrackingFromPoseKps": "Upper Body Tracking From PoseKps (InstanceDiffusion)",
+ "RenderPeopleKps": "Render Pose JSON (Human)",
+ "RenderAnimalKps": "Render Pose JSON (Animal)",
+}
diff --git a/node_wrappers/pyracanny.py b/node_wrappers/pyracanny.py
new file mode 100644
index 0000000000000000000000000000000000000000..996c0b64cd847131d327fb1f94c218782bce94fc
--- /dev/null
+++ b/node_wrappers/pyracanny.py
@@ -0,0 +1,30 @@
+from ..utils import common_annotator_call, INPUT, define_preprocessor_inputs
+import comfy.model_management as model_management
+
+class PyraCanny_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ low_threshold=INPUT.INT(default=64, max=255),
+ high_threshold=INPUT.INT(default=128, max=255),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Line Extractors"
+
+ def execute(self, image, low_threshold=64, high_threshold=128, resolution=512, **kwargs):
+ from custom_controlnet_aux.pyracanny import PyraCannyDetector
+
+ return (common_annotator_call(PyraCannyDetector(), image, low_threshold=low_threshold, high_threshold=high_threshold, resolution=resolution), )
+
+
+
+NODE_CLASS_MAPPINGS = {
+ "PyraCannyPreprocessor": PyraCanny_Preprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "PyraCannyPreprocessor": "PyraCanny"
+}
\ No newline at end of file
diff --git a/node_wrappers/recolor.py b/node_wrappers/recolor.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c33513c2dbc8f9ba8f7ad43412edf162ec5b470
--- /dev/null
+++ b/node_wrappers/recolor.py
@@ -0,0 +1,46 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
+
+class ImageLuminanceDetector:
+ @classmethod
+ def INPUT_TYPES(s):
+ #https://github.com/Mikubill/sd-webui-controlnet/blob/416c345072c9c2066101e225964e3986abe6945e/scripts/processor.py#L1229
+ return define_preprocessor_inputs(
+ gamma_correction=INPUT.FLOAT(default=1.0, min=0.1, max=2.0),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Recolor"
+
+ def execute(self, image, gamma_correction=1.0, resolution=512, **kwargs):
+ from custom_controlnet_aux.recolor import Recolorizer
+ return (common_annotator_call(Recolorizer(), image, mode="luminance", gamma_correction=gamma_correction , resolution=resolution), )
+
+class ImageIntensityDetector:
+ @classmethod
+ def INPUT_TYPES(s):
+ #https://github.com/Mikubill/sd-webui-controlnet/blob/416c345072c9c2066101e225964e3986abe6945e/scripts/processor.py#L1229
+ return define_preprocessor_inputs(
+ gamma_correction=INPUT.FLOAT(default=1.0, min=0.1, max=2.0),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Recolor"
+
+ def execute(self, image, gamma_correction=1.0, resolution=512, **kwargs):
+ from custom_controlnet_aux.recolor import Recolorizer
+ return (common_annotator_call(Recolorizer(), image, mode="intensity", gamma_correction=gamma_correction , resolution=resolution), )
+
+NODE_CLASS_MAPPINGS = {
+ "ImageLuminanceDetector": ImageLuminanceDetector,
+ "ImageIntensityDetector": ImageIntensityDetector
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "ImageLuminanceDetector": "Image Luminance",
+ "ImageIntensityDetector": "Image Intensity"
+}
\ No newline at end of file
diff --git a/node_wrappers/scribble.py b/node_wrappers/scribble.py
new file mode 100644
index 0000000000000000000000000000000000000000..2945502fa7f5d28fe85ce36c78a4936eba2d2d49
--- /dev/null
+++ b/node_wrappers/scribble.py
@@ -0,0 +1,74 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT, nms
+import comfy.model_management as model_management
+import cv2
+
+class Scribble_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Line Extractors"
+
+ def execute(self, image, resolution=512, **kwargs):
+ from custom_controlnet_aux.scribble import ScribbleDetector
+
+ model = ScribbleDetector()
+ return (common_annotator_call(model, image, resolution=resolution), )
+
+class Scribble_XDoG_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ threshold=INPUT.INT(default=32, min=1, max=64),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Line Extractors"
+
+ def execute(self, image, threshold=32, resolution=512, **kwargs):
+ from custom_controlnet_aux.scribble import ScribbleXDog_Detector
+
+ model = ScribbleXDog_Detector()
+ return (common_annotator_call(model, image, resolution=resolution, thr_a=threshold), )
+
+class Scribble_PiDiNet_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ safe=(["enable", "disable"],),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Line Extractors"
+
+ def execute(self, image, safe="enable", resolution=512):
+ def model(img, **kwargs):
+ from custom_controlnet_aux.pidi import PidiNetDetector
+ pidinet = PidiNetDetector.from_pretrained().to(model_management.get_torch_device())
+ result = pidinet(img, scribble=True, **kwargs)
+ result = nms(result, 127, 3.0)
+ result = cv2.GaussianBlur(result, (0, 0), 3.0)
+ result[result > 4] = 255
+ result[result < 255] = 0
+ return result
+ return (common_annotator_call(model, image, resolution=resolution, safe=safe=="enable"),)
+
+NODE_CLASS_MAPPINGS = {
+ "ScribblePreprocessor": Scribble_Preprocessor,
+ "Scribble_XDoG_Preprocessor": Scribble_XDoG_Preprocessor,
+ "Scribble_PiDiNet_Preprocessor": Scribble_PiDiNet_Preprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "ScribblePreprocessor": "Scribble Lines",
+ "Scribble_XDoG_Preprocessor": "Scribble XDoG Lines",
+ "Scribble_PiDiNet_Preprocessor": "Scribble PiDiNet Lines"
+}
diff --git a/node_wrappers/segment_anything.py b/node_wrappers/segment_anything.py
new file mode 100644
index 0000000000000000000000000000000000000000..c91412a5bd222f500a2dd51e8a1abbc640f2dc56
--- /dev/null
+++ b/node_wrappers/segment_anything.py
@@ -0,0 +1,27 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
+import comfy.model_management as model_management
+
+class SAM_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/others"
+
+ def execute(self, image, resolution=512, **kwargs):
+ from custom_controlnet_aux.sam import SamDetector
+
+ mobile_sam = SamDetector.from_pretrained().to(model_management.get_torch_device())
+ out = common_annotator_call(mobile_sam, image, resolution=resolution)
+ del mobile_sam
+ return (out, )
+
+NODE_CLASS_MAPPINGS = {
+ "SAMPreprocessor": SAM_Preprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "SAMPreprocessor": "SAM Segmentor"
+}
\ No newline at end of file
diff --git a/node_wrappers/shuffle.py b/node_wrappers/shuffle.py
new file mode 100644
index 0000000000000000000000000000000000000000..6949f6df6b76fd9048d5174d9c0fe330f999e73d
--- /dev/null
+++ b/node_wrappers/shuffle.py
@@ -0,0 +1,27 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT, MAX_RESOLUTION
+import comfy.model_management as model_management
+
+class Shuffle_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ resolution=INPUT.RESOLUTION(),
+ seed=INPUT.SEED()
+ )
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "preprocess"
+
+ CATEGORY = "ControlNet Preprocessors/T2IAdapter-only"
+
+ def preprocess(self, image, resolution=512, seed=0):
+ from custom_controlnet_aux.shuffle import ContentShuffleDetector
+
+ return (common_annotator_call(ContentShuffleDetector(), image, resolution=resolution, seed=seed), )
+
+NODE_CLASS_MAPPINGS = {
+ "ShufflePreprocessor": Shuffle_Preprocessor
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "ShufflePreprocessor": "Content Shuffle"
+}
\ No newline at end of file
diff --git a/node_wrappers/teed.py b/node_wrappers/teed.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3b533ecdbdf67805b5220b6bf010705e084aeae
--- /dev/null
+++ b/node_wrappers/teed.py
@@ -0,0 +1,30 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
+import comfy.model_management as model_management
+
+class TEED_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ safe_steps=INPUT.INT(default=2, max=10),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Line Extractors"
+
+ def execute(self, image, safe_steps=2, resolution=512, **kwargs):
+ from custom_controlnet_aux.teed import TEDDetector
+
+ model = TEDDetector.from_pretrained().to(model_management.get_torch_device())
+ out = common_annotator_call(model, image, resolution=resolution, safe_steps=safe_steps)
+ del model
+ return (out, )
+
+NODE_CLASS_MAPPINGS = {
+ "TEEDPreprocessor": TEED_Preprocessor,
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "TEED_Preprocessor": "TEED Soft-Edge Lines",
+}
\ No newline at end of file
diff --git a/node_wrappers/tile.py b/node_wrappers/tile.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f90ed91bf5d3254c63e34135afdfa3c89dcd1eb
--- /dev/null
+++ b/node_wrappers/tile.py
@@ -0,0 +1,73 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
+
+
+class Tile_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ pyrUp_iters=INPUT.INT(default=3, min=1, max=10),
+ resolution=INPUT.RESOLUTION()
+ )
+
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/tile"
+
+ def execute(self, image, pyrUp_iters, resolution=512, **kwargs):
+ from custom_controlnet_aux.tile import TileDetector
+
+ return (common_annotator_call(TileDetector(), image, pyrUp_iters=pyrUp_iters, resolution=resolution),)
+
+class TTPlanet_TileGF_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ scale_factor=INPUT.FLOAT(default=1.00, min=1.000, max=8.00),
+ blur_strength=INPUT.FLOAT(default=2.0, min=1.0, max=10.0),
+ radius=INPUT.INT(default=7, min=1, max=20),
+ eps=INPUT.FLOAT(default=0.01, min=0.001, max=0.1, step=0.001),
+ resolution=INPUT.RESOLUTION()
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/tile"
+
+ def execute(self, image, scale_factor, blur_strength, radius, eps, **kwargs):
+ from custom_controlnet_aux.tile import TTPlanet_Tile_Detector_GF
+
+ return (common_annotator_call(TTPlanet_Tile_Detector_GF(), image, scale_factor=scale_factor, blur_strength=blur_strength, radius=radius, eps=eps),)
+
+class TTPlanet_TileSimple_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(
+ scale_factor=INPUT.FLOAT(default=1.00, min=1.000, max=8.00),
+ blur_strength=INPUT.FLOAT(default=2.0, min=1.0, max=10.0),
+ )
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/tile"
+
+ def execute(self, image, scale_factor, blur_strength):
+ from custom_controlnet_aux.tile import TTPLanet_Tile_Detector_Simple
+
+ return (common_annotator_call(TTPLanet_Tile_Detector_Simple(), image, scale_factor=scale_factor, blur_strength=blur_strength),)
+
+
+NODE_CLASS_MAPPINGS = {
+ "TilePreprocessor": Tile_Preprocessor,
+ "TTPlanet_TileGF_Preprocessor": TTPlanet_TileGF_Preprocessor,
+ "TTPlanet_TileSimple_Preprocessor": TTPlanet_TileSimple_Preprocessor
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "TilePreprocessor": "Tile",
+ "TTPlanet_TileGF_Preprocessor": "TTPlanet Tile GuidedFilter",
+ "TTPlanet_TileSimple_Preprocessor": "TTPlanet Tile Simple"
+}
diff --git a/node_wrappers/uniformer.py b/node_wrappers/uniformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..88a82971fb6092671bb395d936bf4055d65a7777
--- /dev/null
+++ b/node_wrappers/uniformer.py
@@ -0,0 +1,34 @@
+import os
+# Disable NPU device initialization and problematic MMCV ops to prevent RuntimeError
+os.environ['NPU_DEVICE_COUNT'] = '0'
+os.environ['MMCV_WITH_OPS'] = '0'
+
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
+import comfy.model_management as model_management
+
+class Uniformer_SemSegPreprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "semantic_segmentate"
+
+ CATEGORY = "ControlNet Preprocessors/Semantic Segmentation"
+
+ def semantic_segmentate(self, image, resolution=512):
+ from custom_controlnet_aux.uniformer import UniformerSegmentor
+
+ model = UniformerSegmentor.from_pretrained().to(model_management.get_torch_device())
+ out = common_annotator_call(model, image, resolution=resolution)
+ del model
+ return (out, )
+
+NODE_CLASS_MAPPINGS = {
+ "UniFormer-SemSegPreprocessor": Uniformer_SemSegPreprocessor,
+ "SemSegPreprocessor": Uniformer_SemSegPreprocessor,
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "UniFormer-SemSegPreprocessor": "UniFormer Segmentor",
+ "SemSegPreprocessor": "Semantic Segmentor (legacy, alias for UniFormer)",
+}
\ No newline at end of file
diff --git a/node_wrappers/unimatch.py b/node_wrappers/unimatch.py
new file mode 100644
index 0000000000000000000000000000000000000000..20abcc7eaa20bd7b4e64c80ce4b60b6b0017a320
--- /dev/null
+++ b/node_wrappers/unimatch.py
@@ -0,0 +1,75 @@
+from ..utils import common_annotator_call
+import comfy.model_management as model_management
+import torch
+import numpy as np
+from einops import rearrange
+import torch.nn.functional as F
+
+class Unimatch_OptFlowPreprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": dict(
+ image=("IMAGE",),
+ ckpt_name=(
+ ["gmflow-scale1-mixdata.pth", "gmflow-scale2-mixdata.pth", "gmflow-scale2-regrefine6-mixdata.pth"],
+ {"default": "gmflow-scale2-regrefine6-mixdata.pth"}
+ ),
+ backward_flow=("BOOLEAN", {"default": False}),
+ bidirectional_flow=("BOOLEAN", {"default": False})
+ )
+ }
+
+ RETURN_TYPES = ("OPTICAL_FLOW", "IMAGE")
+ RETURN_NAMES = ("OPTICAL_FLOW", "PREVIEW_IMAGE")
+ FUNCTION = "estimate"
+
+ CATEGORY = "ControlNet Preprocessors/Optical Flow"
+
+ def estimate(self, image, ckpt_name, backward_flow=False, bidirectional_flow=False):
+ assert len(image) > 1, "[Unimatch] Requiring as least two frames as an optical flow estimator. Only use this node on video input."
+ from custom_controlnet_aux.unimatch import UnimatchDetector
+ tensor_images = image
+ model = UnimatchDetector.from_pretrained(filename=ckpt_name).to(model_management.get_torch_device())
+ flows, vis_flows = [], []
+ for i in range(len(tensor_images) - 1):
+ image0, image1 = np.asarray(image[i:i+2].cpu() * 255., dtype=np.uint8)
+ flow, vis_flow = model(image0, image1, output_type="np", pred_bwd_flow=backward_flow, pred_bidir_flow=bidirectional_flow)
+ flows.append(torch.from_numpy(flow).float())
+ vis_flows.append(torch.from_numpy(vis_flow).float() / 255.)
+ del model
+ return (torch.stack(flows, dim=0), torch.stack(vis_flows, dim=0))
+
+class MaskOptFlow:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": dict(optical_flow=("OPTICAL_FLOW",), mask=("MASK",))
+ }
+
+ RETURN_TYPES = ("OPTICAL_FLOW", "IMAGE")
+ RETURN_NAMES = ("OPTICAL_FLOW", "PREVIEW_IMAGE")
+ FUNCTION = "mask_opt_flow"
+
+ CATEGORY = "ControlNet Preprocessors/Optical Flow"
+
+ def mask_opt_flow(self, optical_flow, mask):
+ from custom_controlnet_aux.unimatch import flow_to_image
+ assert len(mask) >= len(optical_flow), f"Not enough masks to mask optical flow: {len(mask)} vs {len(optical_flow)}"
+ mask = mask[:optical_flow.shape[0]]
+ mask = F.interpolate(mask, optical_flow.shape[1:3])
+ mask = rearrange(mask, "n 1 h w -> n h w 1")
+ vis_flows = torch.stack([torch.from_numpy(flow_to_image(flow)).float() / 255. for flow in optical_flow.numpy()], dim=0)
+ vis_flows *= mask
+ optical_flow *= mask
+ return (optical_flow, vis_flows)
+
+
+NODE_CLASS_MAPPINGS = {
+ "Unimatch_OptFlowPreprocessor": Unimatch_OptFlowPreprocessor,
+ "MaskOptFlow": MaskOptFlow
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "Unimatch_OptFlowPreprocessor": "Unimatch Optical Flow",
+ "MaskOptFlow": "Mask Optical Flow (DragNUWA)"
+}
\ No newline at end of file
diff --git a/node_wrappers/zoe.py b/node_wrappers/zoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebc82d7326ec36e4ac415b541c48544f0ef15b0e
--- /dev/null
+++ b/node_wrappers/zoe.py
@@ -0,0 +1,27 @@
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
+import comfy.model_management as model_management
+
+class Zoe_Depth_Map_Preprocessor:
+ @classmethod
+ def INPUT_TYPES(s):
+ return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "execute"
+
+ CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
+
+ def execute(self, image, resolution=512, **kwargs):
+ from custom_controlnet_aux.zoe import ZoeDetector
+
+ model = ZoeDetector.from_pretrained().to(model_management.get_torch_device())
+ out = common_annotator_call(model, image, resolution=resolution)
+ del model
+ return (out, )
+
+NODE_CLASS_MAPPINGS = {
+ "Zoe-DepthMapPreprocessor": Zoe_Depth_Map_Preprocessor
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "Zoe-DepthMapPreprocessor": "Zoe Depth Map"
+}
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..8c2391ef0d51037f1d7388b1e3956b3d94406a38
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,14 @@
+[project]
+name = "comfyui_controlnet_aux"
+description = "Plug-and-play ComfyUI node sets for making ControlNet hint images"
+
+version = "1.1.2"
+dependencies = ["torch", "importlib_metadata", "huggingface_hub", "scipy", "opencv-python>=4.7.0.72", "filelock", "numpy", "Pillow", "einops", "torchvision", "pyyaml", "scikit-image", "python-dateutil", "mediapipe", "svglib", "fvcore", "yapf", "omegaconf", "ftfy", "addict", "yacs", "trimesh[easy]", "albumentations", "scikit-learn", "matplotlib"]
+
+[project.urls]
+Repository = "https://github.com/Fannovel16/comfyui_controlnet_aux"
+
+[tool.comfy]
+PublisherId = "fannovel16"
+DisplayName = "comfyui_controlnet_aux"
+Icon = ""
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4140a38bd76a7c2fde43e060e4b70dd8883d68e6
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,26 @@
+torch
+importlib_metadata
+huggingface_hub
+scipy
+opencv-python>=4.7.0.72
+filelock
+numpy
+Pillow
+einops
+torchvision
+pyyaml
+scikit-image
+python-dateutil
+mediapipe
+svglib
+fvcore
+yapf
+omegaconf
+ftfy
+addict
+yacs
+yapf
+trimesh[easy]
+albumentations
+scikit-learn
+matplotlib
diff --git a/search_hf_assets.py b/search_hf_assets.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3998259e1eca7f9078ada89e56c0451824b662d
--- /dev/null
+++ b/search_hf_assets.py
@@ -0,0 +1,56 @@
+from pathlib import Path
+import os
+import re
+#Thanks ChatGPT
+pattern = r'\bfrom_pretrained\(.*?pretrained_model_or_path\s*=\s*(.*?)(?:,|\))|filename\s*=\s*(.*?)(?:,|\))|(\w+_filename)\s*=\s*(.*?)(?:,|\))'
+aux_dir = Path(__file__).parent / 'src' / 'custom_controlnet_aux'
+VAR_DICT = dict(
+ HF_MODEL_NAME = "lllyasviel/Annotators",
+ DWPOSE_MODEL_NAME = "yzd-v/DWPose",
+ BDS_MODEL_NAME = "bdsqlsz/qinglong_controlnet-lllite",
+ DENSEPOSE_MODEL_NAME = "LayerNorm/DensePose-TorchScript-with-hint-image",
+ MESH_GRAPHORMER_MODEL_NAME = "hr16/ControlNet-HandRefiner-pruned",
+ SAM_MODEL_NAME = "dhkim2810/MobileSAM",
+ UNIMATCH_MODEL_NAME = "hr16/Unimatch",
+ DEPTH_ANYTHING_MODEL_NAME = "LiheYoung/Depth-Anything", #HF Space
+ DIFFUSION_EDGE_MODEL_NAME = "hr16/Diffusion-Edge"
+)
+re_result_dict = {}
+for preprocc in os.listdir(aux_dir):
+ if preprocc in ["__pycache__", 'tests']: continue
+ if '.py' in preprocc: continue
+ f = open(aux_dir / preprocc / '__init__.py', 'r')
+ code = f.read()
+ matches = re.findall(pattern, code)
+ result = [match[0] or match[1] or match[3] for match in matches]
+ if not len(result):
+ print(preprocc)
+ continue
+ result = [el.replace("'", '').replace('"', '') for el in result]
+ result = [VAR_DICT.get(el, el) for el in result]
+ re_result_dict[preprocc] = result
+ f.close()
+
+for preprocc, re_result in re_result_dict.items():
+ model_name, filenames = re_result[0], re_result[1:]
+ print(f"* {preprocc}: ", end=' ')
+ assests_md = ', '.join([f"[{model_name}/{filename}](https://huggingface.co/{model_name}/blob/main/{filename})" for filename in filenames])
+ print(assests_md)
+
+preprocc = "dwpose"
+model_name, filenames = VAR_DICT['DWPOSE_MODEL_NAME'], ["yolox_l.onnx", "dw-ll_ucoco_384.onnx"]
+print(f"* {preprocc}: ", end=' ')
+assests_md = ', '.join([f"[{model_name}/{filename}](https://huggingface.co/{model_name}/blob/main/{filename})" for filename in filenames])
+print(assests_md)
+
+preprocc = "yolo-nas"
+model_name, filenames = "hr16/yolo-nas-fp16", ["yolo_nas_l_fp16.onnx", "yolo_nas_m_fp16.onnx", "yolo_nas_s_fp16.onnx"]
+print(f"* {preprocc}: ", end=' ')
+assests_md = ', '.join([f"[{model_name}/{filename}](https://huggingface.co/{model_name}/blob/main/{filename})" for filename in filenames])
+print(assests_md)
+
+preprocc = "dwpose-torchscript"
+model_name, filenames = "hr16/DWPose-TorchScript-BatchSize5", ["dw-ll_ucoco_384_bs5.torchscript.pt", "rtmpose-m_ap10k_256_bs5.torchscript.pt"]
+print(f"* {preprocc}: ", end=' ')
+assests_md = ', '.join([f"[{model_name}/{filename}](https://huggingface.co/{model_name}/blob/main/{filename})" for filename in filenames])
+print(assests_md)
\ No newline at end of file
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..33e7a7f594ef441479257c788e4c0d6e08657fc8
--- /dev/null
+++ b/src/__init__.py
@@ -0,0 +1 @@
+#Dummy file ensuring this package will be recognized
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/__init__.py b/src/custom_controlnet_aux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..33e7a7f594ef441479257c788e4c0d6e08657fc8
--- /dev/null
+++ b/src/custom_controlnet_aux/__init__.py
@@ -0,0 +1 @@
+#Dummy file ensuring this package will be recognized
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/anime_face_segment/__init__.py b/src/custom_controlnet_aux/anime_face_segment/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3d8ace514883275a4021d0e9a0bc2640e59ca9a
--- /dev/null
+++ b/src/custom_controlnet_aux/anime_face_segment/__init__.py
@@ -0,0 +1,66 @@
+from .network import UNet
+from .util import seg2img
+import torch
+import os
+import cv2
+from custom_controlnet_aux.util import HWC3, resize_image_with_pad, common_input_validate, custom_hf_download, BDS_MODEL_NAME
+from huggingface_hub import hf_hub_download
+from PIL import Image
+from einops import rearrange
+from .anime_segmentation import AnimeSegmentation
+import numpy as np
+
+class AnimeFaceSegmentor:
+ def __init__(self, model, seg_model):
+ self.model = model
+ self.seg_model = seg_model
+ self.device = "cpu"
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path=BDS_MODEL_NAME, filename="UNet.pth", seg_filename="isnetis.ckpt"):
+ model_path = custom_hf_download(pretrained_model_or_path, filename, subfolder="Annotators")
+ seg_model_path = custom_hf_download("skytnt/anime-seg", seg_filename)
+
+ model = UNet()
+ ckpt = torch.load(model_path, map_location="cpu")
+ model.load_state_dict(ckpt)
+ model.eval()
+
+ seg_model = AnimeSegmentation(seg_model_path)
+ seg_model.net.eval()
+ return cls(model, seg_model)
+
+ def to(self, device):
+ self.model.to(device)
+ self.seg_model.net.to(device)
+ self.device = device
+ return self
+
+ def __call__(self, input_image, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", remove_background=True, **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+
+ with torch.no_grad():
+ if remove_background:
+ print(input_image.shape)
+ mask, input_image = self.seg_model(input_image, 0) #Don't resize image as it is resized
+ image_feed = torch.from_numpy(input_image).float().to(self.device)
+ image_feed = rearrange(image_feed, 'h w c -> 1 c h w')
+ image_feed = image_feed / 255
+ seg = self.model(image_feed).squeeze(dim=0)
+ result = seg2img(seg.cpu().detach().numpy())
+
+ detected_map = HWC3(result)
+ detected_map = remove_pad(detected_map)
+ if remove_background:
+ mask = remove_pad(mask)
+ H, W, C = detected_map.shape
+ tmp = np.zeros([H, W, C + 1])
+ tmp[:,:,:C] = detected_map
+ tmp[:,:,3:] = mask
+ detected_map = tmp
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map[..., :3])
+
+ return detected_map
diff --git a/src/custom_controlnet_aux/anime_face_segment/anime_segmentation.py b/src/custom_controlnet_aux/anime_face_segment/anime_segmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac4ab9c0effe4a84b3a76fc1b2cffd373109a035
--- /dev/null
+++ b/src/custom_controlnet_aux/anime_face_segment/anime_segmentation.py
@@ -0,0 +1,58 @@
+#https://github.com/SkyTNT/anime-segmentation/tree/main
+#Only adapt isnet_is (https://huggingface.co/skytnt/anime-seg/blob/main/isnetis.ckpt)
+import torch.nn as nn
+import torch
+from .isnet import ISNetDIS
+import numpy as np
+import cv2
+from comfy.model_management import get_torch_device
+DEVICE = get_torch_device()
+
+class AnimeSegmentation:
+ def __init__(self, ckpt_path):
+ super(AnimeSegmentation).__init__()
+ sd = torch.load(ckpt_path, map_location="cpu")
+ self.net = ISNetDIS()
+ #gt_encoder isn't used during inference
+ self.net.load_state_dict({k.replace("net.", ''):v for k, v in sd.items() if k.startswith("net.")})
+ self.net = self.net.to(DEVICE)
+ self.net.eval()
+
+ def get_mask(self, input_img, s=640):
+ input_img = (input_img / 255).astype(np.float32)
+ if s == 0:
+ img_input = np.transpose(input_img, (2, 0, 1))
+ img_input = img_input[np.newaxis, :]
+ tmpImg = torch.from_numpy(img_input).float().to(DEVICE)
+ with torch.no_grad():
+ pred = self.net(tmpImg)[0][0].sigmoid() #https://github.com/SkyTNT/anime-segmentation/blob/main/train.py#L92C20-L92C47
+ pred = pred.cpu().numpy()[0]
+ pred = np.transpose(pred, (1, 2, 0))
+ #pred = pred[:, :, np.newaxis]
+ return pred
+
+ h, w = h0, w0 = input_img.shape[:-1]
+ h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s)
+ ph, pw = s - h, s - w
+ img_input = np.zeros([s, s, 3], dtype=np.float32)
+ img_input[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(input_img, (w, h))
+ img_input = np.transpose(img_input, (2, 0, 1))
+ img_input = img_input[np.newaxis, :]
+ tmpImg = torch.from_numpy(img_input).float().to(DEVICE)
+ with torch.no_grad():
+ pred = self.net(tmpImg)[0][0].sigmoid() #https://github.com/SkyTNT/anime-segmentation/blob/main/train.py#L92C20-L92C47
+ pred = pred.cpu().numpy()[0]
+ pred = np.transpose(pred, (1, 2, 0))
+ pred = pred[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
+ #pred = cv2.resize(pred, (w0, h0))[:, :, np.newaxis]
+ pred = cv2.resize(pred, (w0, h0))
+ return pred
+
+ def __call__(self, np_img, img_size):
+ mask = self.get_mask(np_img, int(img_size))
+ np_img = (mask * np_img + 255 * (1 - mask)).astype(np.uint8)
+ mask = (mask * 255).astype(np.uint8)
+ #np_img = np.concatenate([np_img, mask], axis=2, dtype=np.uint8)
+ #mask = mask.repeat(3, axis=2)
+ return mask, np_img
+
diff --git a/src/custom_controlnet_aux/anime_face_segment/isnet.py b/src/custom_controlnet_aux/anime_face_segment/isnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a0a504ecadf426358c8accee62d3f35129b0a11
--- /dev/null
+++ b/src/custom_controlnet_aux/anime_face_segment/isnet.py
@@ -0,0 +1,619 @@
+# Codes are borrowed from
+# https://github.com/xuebinqin/DIS/blob/main/IS-Net/models/isnet.py
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import models
+
+bce_loss = nn.BCEWithLogitsLoss(reduction="mean")
+
+
+def muti_loss_fusion(preds, target):
+ loss0 = 0.0
+ loss = 0.0
+
+ for i in range(0, len(preds)):
+ if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]:
+ tmp_target = F.interpolate(
+ target, size=preds[i].size()[2:], mode="bilinear", align_corners=True
+ )
+ loss = loss + bce_loss(preds[i], tmp_target)
+ else:
+ loss = loss + bce_loss(preds[i], target)
+ if i == 0:
+ loss0 = loss
+ return loss0, loss
+
+
+fea_loss = nn.MSELoss(reduction="mean")
+kl_loss = nn.KLDivLoss(reduction="mean")
+l1_loss = nn.L1Loss(reduction="mean")
+smooth_l1_loss = nn.SmoothL1Loss(reduction="mean")
+
+
+def muti_loss_fusion_kl(preds, target, dfs, fs, mode="MSE"):
+ loss0 = 0.0
+ loss = 0.0
+
+ for i in range(0, len(preds)):
+ if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]:
+ tmp_target = F.interpolate(
+ target, size=preds[i].size()[2:], mode="bilinear", align_corners=True
+ )
+ loss = loss + bce_loss(preds[i], tmp_target)
+ else:
+ loss = loss + bce_loss(preds[i], target)
+ if i == 0:
+ loss0 = loss
+
+ for i in range(0, len(dfs)):
+ df = dfs[i]
+ fs_i = fs[i]
+ if mode == "MSE":
+ loss = loss + fea_loss(
+ df, fs_i
+ ) ### add the mse loss of features as additional constraints
+ elif mode == "KL":
+ loss = loss + kl_loss(F.log_softmax(df, dim=1), F.softmax(fs_i, dim=1))
+ elif mode == "MAE":
+ loss = loss + l1_loss(df, fs_i)
+ elif mode == "SmoothL1":
+ loss = loss + smooth_l1_loss(df, fs_i)
+
+ return loss0, loss
+
+
+class REBNCONV(nn.Module):
+ def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
+ super(REBNCONV, self).__init__()
+
+ self.conv_s1 = nn.Conv2d(
+ in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
+ )
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
+ self.relu_s1 = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ hx = x
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
+
+ return xout
+
+
+## upsample tensor 'src' to have the same spatial size with tensor 'tar'
+def _upsample_like(src, tar):
+ src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)
+
+ return src
+
+
+### RSU-7 ###
+class RSU7(nn.Module):
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
+ super(RSU7, self).__init__()
+
+ self.in_ch = in_ch
+ self.mid_ch = mid_ch
+ self.out_ch = out_ch
+
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
+
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+ self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+ self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+ self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+
+ hx = x
+ hxin = self.rebnconvin(hx)
+
+ hx1 = self.rebnconv1(hxin)
+ hx = self.pool1(hx1)
+
+ hx2 = self.rebnconv2(hx)
+ hx = self.pool2(hx2)
+
+ hx3 = self.rebnconv3(hx)
+ hx = self.pool3(hx3)
+
+ hx4 = self.rebnconv4(hx)
+ hx = self.pool4(hx4)
+
+ hx5 = self.rebnconv5(hx)
+ hx = self.pool5(hx5)
+
+ hx6 = self.rebnconv6(hx)
+
+ hx7 = self.rebnconv7(hx6)
+
+ hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
+ hx6dup = _upsample_like(hx6d, hx5)
+
+ hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
+ hx5dup = _upsample_like(hx5d, hx4)
+
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+ hx4dup = _upsample_like(hx4d, hx3)
+
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+ hx3dup = _upsample_like(hx3d, hx2)
+
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+ hx2dup = _upsample_like(hx2d, hx1)
+
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+ return hx1d + hxin
+
+
+### RSU-6 ###
+class RSU6(nn.Module):
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+ super(RSU6, self).__init__()
+
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+ def forward(self, x):
+ hx = x
+
+ hxin = self.rebnconvin(hx)
+
+ hx1 = self.rebnconv1(hxin)
+ hx = self.pool1(hx1)
+
+ hx2 = self.rebnconv2(hx)
+ hx = self.pool2(hx2)
+
+ hx3 = self.rebnconv3(hx)
+ hx = self.pool3(hx3)
+
+ hx4 = self.rebnconv4(hx)
+ hx = self.pool4(hx4)
+
+ hx5 = self.rebnconv5(hx)
+
+ hx6 = self.rebnconv6(hx5)
+
+ hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
+ hx5dup = _upsample_like(hx5d, hx4)
+
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+ hx4dup = _upsample_like(hx4d, hx3)
+
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+ hx3dup = _upsample_like(hx3d, hx2)
+
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+ hx2dup = _upsample_like(hx2d, hx1)
+
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+ return hx1d + hxin
+
+
+### RSU-5 ###
+class RSU5(nn.Module):
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+ super(RSU5, self).__init__()
+
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+ def forward(self, x):
+ hx = x
+
+ hxin = self.rebnconvin(hx)
+
+ hx1 = self.rebnconv1(hxin)
+ hx = self.pool1(hx1)
+
+ hx2 = self.rebnconv2(hx)
+ hx = self.pool2(hx2)
+
+ hx3 = self.rebnconv3(hx)
+ hx = self.pool3(hx3)
+
+ hx4 = self.rebnconv4(hx)
+
+ hx5 = self.rebnconv5(hx4)
+
+ hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
+ hx4dup = _upsample_like(hx4d, hx3)
+
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+ hx3dup = _upsample_like(hx3d, hx2)
+
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+ hx2dup = _upsample_like(hx2d, hx1)
+
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+ return hx1d + hxin
+
+
+### RSU-4 ###
+class RSU4(nn.Module):
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+ super(RSU4, self).__init__()
+
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+ def forward(self, x):
+ hx = x
+
+ hxin = self.rebnconvin(hx)
+
+ hx1 = self.rebnconv1(hxin)
+ hx = self.pool1(hx1)
+
+ hx2 = self.rebnconv2(hx)
+ hx = self.pool2(hx2)
+
+ hx3 = self.rebnconv3(hx)
+
+ hx4 = self.rebnconv4(hx3)
+
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+ hx3dup = _upsample_like(hx3d, hx2)
+
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+ hx2dup = _upsample_like(hx2d, hx1)
+
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+ return hx1d + hxin
+
+
+### RSU-4F ###
+class RSU4F(nn.Module):
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+ super(RSU4F, self).__init__()
+
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
+
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
+
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+ def forward(self, x):
+ hx = x
+
+ hxin = self.rebnconvin(hx)
+
+ hx1 = self.rebnconv1(hxin)
+ hx2 = self.rebnconv2(hx1)
+ hx3 = self.rebnconv3(hx2)
+
+ hx4 = self.rebnconv4(hx3)
+
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+ hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
+ hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
+
+ return hx1d + hxin
+
+
+class myrebnconv(nn.Module):
+ def __init__(
+ self,
+ in_ch=3,
+ out_ch=1,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ dilation=1,
+ groups=1,
+ ):
+ super(myrebnconv, self).__init__()
+
+ self.conv = nn.Conv2d(
+ in_ch,
+ out_ch,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ )
+ self.bn = nn.BatchNorm2d(out_ch)
+ self.rl = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ return self.rl(self.bn(self.conv(x)))
+
+
+class ISNetGTEncoder(nn.Module):
+ def __init__(self, in_ch=1, out_ch=1):
+ super(ISNetGTEncoder, self).__init__()
+
+ self.conv_in = myrebnconv(
+ in_ch, 16, 3, stride=2, padding=1
+ ) # nn.Conv2d(in_ch,64,3,stride=2,padding=1)
+
+ self.stage1 = RSU7(16, 16, 64)
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.stage2 = RSU6(64, 16, 64)
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.stage3 = RSU5(64, 32, 128)
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.stage4 = RSU4(128, 32, 256)
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.stage5 = RSU4F(256, 64, 512)
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.stage6 = RSU4F(512, 64, 512)
+
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
+
+ @staticmethod
+ def compute_loss(args):
+ preds, targets = args
+ return muti_loss_fusion(preds, targets)
+
+ def forward(self, x):
+ hx = x
+
+ hxin = self.conv_in(hx)
+ # hx = self.pool_in(hxin)
+
+ # stage 1
+ hx1 = self.stage1(hxin)
+ hx = self.pool12(hx1)
+
+ # stage 2
+ hx2 = self.stage2(hx)
+ hx = self.pool23(hx2)
+
+ # stage 3
+ hx3 = self.stage3(hx)
+ hx = self.pool34(hx3)
+
+ # stage 4
+ hx4 = self.stage4(hx)
+ hx = self.pool45(hx4)
+
+ # stage 5
+ hx5 = self.stage5(hx)
+ hx = self.pool56(hx5)
+
+ # stage 6
+ hx6 = self.stage6(hx)
+
+ # side output
+ d1 = self.side1(hx1)
+ d1 = _upsample_like(d1, x)
+
+ d2 = self.side2(hx2)
+ d2 = _upsample_like(d2, x)
+
+ d3 = self.side3(hx3)
+ d3 = _upsample_like(d3, x)
+
+ d4 = self.side4(hx4)
+ d4 = _upsample_like(d4, x)
+
+ d5 = self.side5(hx5)
+ d5 = _upsample_like(d5, x)
+
+ d6 = self.side6(hx6)
+ d6 = _upsample_like(d6, x)
+
+ # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
+
+ # return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1, hx2, hx3, hx4, hx5, hx6]
+ return [d1, d2, d3, d4, d5, d6], [hx1, hx2, hx3, hx4, hx5, hx6]
+
+
+class ISNetDIS(nn.Module):
+ def __init__(self, in_ch=3, out_ch=1):
+ super(ISNetDIS, self).__init__()
+
+ self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
+ self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.stage1 = RSU7(64, 32, 64)
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.stage2 = RSU6(64, 32, 128)
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.stage3 = RSU5(128, 64, 256)
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.stage4 = RSU4(256, 128, 512)
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.stage5 = RSU4F(512, 256, 512)
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.stage6 = RSU4F(512, 256, 512)
+
+ # decoder
+ self.stage5d = RSU4F(1024, 256, 512)
+ self.stage4d = RSU4(1024, 128, 256)
+ self.stage3d = RSU5(512, 64, 128)
+ self.stage2d = RSU6(256, 32, 64)
+ self.stage1d = RSU7(128, 16, 64)
+
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
+
+ # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
+
+ @staticmethod
+ def compute_loss_kl(preds, targets, dfs, fs, mode="MSE"):
+ return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)
+
+ @staticmethod
+ def compute_loss(args):
+ if len(args) == 3:
+ ds, dfs, labels = args
+ return muti_loss_fusion(ds, labels)
+ else:
+ ds, dfs, labels, fs = args
+ return muti_loss_fusion_kl(ds, labels, dfs, fs, mode="MSE")
+
+ def forward(self, x):
+ hx = x
+
+ hxin = self.conv_in(hx)
+ hx = self.pool_in(hxin)
+
+ # stage 1
+ hx1 = self.stage1(hxin)
+ hx = self.pool12(hx1)
+
+ # stage 2
+ hx2 = self.stage2(hx)
+ hx = self.pool23(hx2)
+
+ # stage 3
+ hx3 = self.stage3(hx)
+ hx = self.pool34(hx3)
+
+ # stage 4
+ hx4 = self.stage4(hx)
+ hx = self.pool45(hx4)
+
+ # stage 5
+ hx5 = self.stage5(hx)
+ hx = self.pool56(hx5)
+
+ # stage 6
+ hx6 = self.stage6(hx)
+ hx6up = _upsample_like(hx6, hx5)
+
+ # -------------------- decoder --------------------
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+ hx5dup = _upsample_like(hx5d, hx4)
+
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+ hx4dup = _upsample_like(hx4d, hx3)
+
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+ hx3dup = _upsample_like(hx3d, hx2)
+
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+ hx2dup = _upsample_like(hx2d, hx1)
+
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+
+ # side output
+ d1 = self.side1(hx1d)
+ d1 = _upsample_like(d1, x)
+
+ d2 = self.side2(hx2d)
+ d2 = _upsample_like(d2, x)
+
+ d3 = self.side3(hx3d)
+ d3 = _upsample_like(d3, x)
+
+ d4 = self.side4(hx4d)
+ d4 = _upsample_like(d4, x)
+
+ d5 = self.side5(hx5d)
+ d5 = _upsample_like(d5, x)
+
+ d6 = self.side6(hx6)
+ d6 = _upsample_like(d6, x)
+
+ # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
+
+ # return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
+ return [d1, d2, d3, d4, d5, d6], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/anime_face_segment/network.py b/src/custom_controlnet_aux/anime_face_segment/network.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c5929546c82181fe47f604048d9bfe5afb16031
--- /dev/null
+++ b/src/custom_controlnet_aux/anime_face_segment/network.py
@@ -0,0 +1,100 @@
+#https://github.com/siyeong0/Anime-Face-Segmentation/blob/main/network.py
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+
+from custom_controlnet_aux.util import custom_torch_download
+
+class UNet(nn.Module):
+ def __init__(self):
+ super(UNet, self).__init__()
+ self.NUM_SEG_CLASSES = 7 # Background, hair, face, eye, mouth, skin, clothes
+
+ mobilenet_v2 = torchvision.models.mobilenet_v2(pretrained=False)
+ mobilenet_v2.load_state_dict(torch.load(custom_torch_download(filename="mobilenet_v2-b0353104.pth")), strict=True)
+ mob_blocks = mobilenet_v2.features
+
+ # Encoder
+ self.en_block0 = nn.Sequential( # in_ch=3 out_ch=16
+ mob_blocks[0],
+ mob_blocks[1]
+ )
+ self.en_block1 = nn.Sequential( # in_ch=16 out_ch=24
+ mob_blocks[2],
+ mob_blocks[3],
+ )
+ self.en_block2 = nn.Sequential( # in_ch=24 out_ch=32
+ mob_blocks[4],
+ mob_blocks[5],
+ mob_blocks[6],
+ )
+ self.en_block3 = nn.Sequential( # in_ch=32 out_ch=96
+ mob_blocks[7],
+ mob_blocks[8],
+ mob_blocks[9],
+ mob_blocks[10],
+ mob_blocks[11],
+ mob_blocks[12],
+ mob_blocks[13],
+ )
+ self.en_block4 = nn.Sequential( # in_ch=96 out_ch=160
+ mob_blocks[14],
+ mob_blocks[15],
+ mob_blocks[16],
+ )
+
+ # Decoder
+ self.de_block4 = nn.Sequential( # in_ch=160 out_ch=96
+ nn.UpsamplingNearest2d(scale_factor=2),
+ nn.Conv2d(160, 96, kernel_size=3, padding=1),
+ nn.InstanceNorm2d(96),
+ nn.LeakyReLU(0.1),
+ nn.Dropout(p=0.2)
+ )
+ self.de_block3 = nn.Sequential( # in_ch=96x2 out_ch=32
+ nn.UpsamplingNearest2d(scale_factor=2),
+ nn.Conv2d(96*2, 32, kernel_size=3, padding=1),
+ nn.InstanceNorm2d(32),
+ nn.LeakyReLU(0.1),
+ nn.Dropout(p=0.2)
+ )
+ self.de_block2 = nn.Sequential( # in_ch=32x2 out_ch=24
+ nn.UpsamplingNearest2d(scale_factor=2),
+ nn.Conv2d(32*2, 24, kernel_size=3, padding=1),
+ nn.InstanceNorm2d(24),
+ nn.LeakyReLU(0.1),
+ nn.Dropout(p=0.2)
+ )
+ self.de_block1 = nn.Sequential( # in_ch=24x2 out_ch=16
+ nn.UpsamplingNearest2d(scale_factor=2),
+ nn.Conv2d(24*2, 16, kernel_size=3, padding=1),
+ nn.InstanceNorm2d(16),
+ nn.LeakyReLU(0.1),
+ nn.Dropout(p=0.2)
+ )
+
+ self.de_block0 = nn.Sequential( # in_ch=16x2 out_ch=7
+ nn.UpsamplingNearest2d(scale_factor=2),
+ nn.Conv2d(16*2, self.NUM_SEG_CLASSES, kernel_size=3, padding=1),
+ nn.Softmax2d()
+ )
+
+ def forward(self, x):
+ e0 = self.en_block0(x)
+ e1 = self.en_block1(e0)
+ e2 = self.en_block2(e1)
+ e3 = self.en_block3(e2)
+ e4 = self.en_block4(e3)
+
+ d4 = self.de_block4(e4)
+ c4 = torch.cat((d4,e3),1)
+ d3 = self.de_block3(c4)
+ c3 = torch.cat((d3,e2),1)
+ d2 = self.de_block2(c3)
+ c2 =torch.cat((d2,e1),1)
+ d1 = self.de_block1(c2)
+ c1 = torch.cat((d1,e0),1)
+ y = self.de_block0(c1)
+
+ return y
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/anime_face_segment/util.py b/src/custom_controlnet_aux/anime_face_segment/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6f3d22543675f9b098b8bc57d244c9e437d0636
--- /dev/null
+++ b/src/custom_controlnet_aux/anime_face_segment/util.py
@@ -0,0 +1,40 @@
+#https://github.com/siyeong0/Anime-Face-Segmentation/blob/main/util.py
+#The color palette is changed according to https://github.com/Mikubill/sd-webui-controlnet/blob/91f67ddcc7bc47537a6285864abfc12590f46c3f/annotator/anime_face_segment/__init__.py
+import cv2 as cv
+import glob
+import numpy as np
+import os
+
+"""
+COLOR_BACKGROUND = (0,255,255)
+COLOR_HAIR = (255,0,0)
+COLOR_EYE = (0,0,255)
+COLOR_MOUTH = (255,255,255)
+COLOR_FACE = (0,255,0)
+COLOR_SKIN = (255,255,0)
+COLOR_CLOTHES = (255,0,255)
+"""
+COLOR_BACKGROUND = (255,255,0)
+COLOR_HAIR = (0,0,255)
+COLOR_EYE = (255,0,0)
+COLOR_MOUTH = (255,255,255)
+COLOR_FACE = (0,255,0)
+COLOR_SKIN = (0,255,255)
+COLOR_CLOTHES = (255,0,255)
+PALETTE = [COLOR_BACKGROUND,COLOR_HAIR,COLOR_EYE,COLOR_MOUTH,COLOR_FACE,COLOR_SKIN,COLOR_CLOTHES]
+
+def img2seg(path):
+ src = cv.imread(path)
+ src = src.reshape(-1, 3)
+ seg_list = []
+ for color in PALETTE:
+ seg_list.append(np.where(np.all(src==color, axis=1), 1.0, 0.0))
+ dst = np.stack(seg_list,axis=1).reshape(512,512,7)
+
+ return dst.astype(np.float32)
+
+def seg2img(src):
+ src = np.moveaxis(src,0,2)
+ dst = [[PALETTE[np.argmax(val)] for val in buf]for buf in src]
+
+ return np.array(dst).astype(np.uint8)
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/binary/__init__.py b/src/custom_controlnet_aux/binary/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..34fd63fa2f29d310a5c1e9785e7997b6fdbd8227
--- /dev/null
+++ b/src/custom_controlnet_aux/binary/__init__.py
@@ -0,0 +1,38 @@
+import warnings
+import cv2
+import numpy as np
+from PIL import Image
+from custom_controlnet_aux.util import HWC3, resize_image_with_pad
+
+class BinaryDetector:
+ def __call__(self, input_image=None, bin_threshold=0, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
+ if "img" in kwargs:
+ warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning)
+ input_image = kwargs.pop("img")
+
+ if input_image is None:
+ raise ValueError("input_image must be defined.")
+
+ if not isinstance(input_image, np.ndarray):
+ input_image = np.array(input_image, dtype=np.uint8)
+ output_type = output_type or "pil"
+ else:
+ output_type = output_type or "np"
+
+ detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+
+ img_gray = cv2.cvtColor(detected_map, cv2.COLOR_RGB2GRAY)
+ if bin_threshold == 0 or bin_threshold == 255:
+ # Otsu's threshold
+ otsu_threshold, img_bin = cv2.threshold(img_gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
+ print("Otsu threshold:", otsu_threshold)
+ else:
+ _, img_bin = cv2.threshold(img_gray, bin_threshold, 255, cv2.THRESH_BINARY_INV)
+
+ detected_map = cv2.cvtColor(img_bin, cv2.COLOR_GRAY2RGB)
+ detected_map = HWC3(remove_pad(255 - detected_map))
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
diff --git a/src/custom_controlnet_aux/canny/__init__.py b/src/custom_controlnet_aux/canny/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f91c6dda3716c950fc55d6c0b6c53c81a32b7df
--- /dev/null
+++ b/src/custom_controlnet_aux/canny/__init__.py
@@ -0,0 +1,17 @@
+import warnings
+import cv2
+import numpy as np
+from PIL import Image
+from custom_controlnet_aux.util import resize_image_with_pad, common_input_validate, HWC3
+
+class CannyDetector:
+ def __call__(self, input_image=None, low_threshold=100, high_threshold=200, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+ detected_map = cv2.Canny(detected_map, low_threshold, high_threshold)
+ detected_map = HWC3(remove_pad(detected_map))
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
diff --git a/src/custom_controlnet_aux/color/__init__.py b/src/custom_controlnet_aux/color/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d2f010b6901128f49f926aa0c407a07c7d471e3
--- /dev/null
+++ b/src/custom_controlnet_aux/color/__init__.py
@@ -0,0 +1,37 @@
+import cv2
+import warnings
+import cv2
+import numpy as np
+from PIL import Image
+from custom_controlnet_aux.util import HWC3, safer_memory, common_input_validate
+
+def cv2_resize_shortest_edge(image, size):
+ h, w = image.shape[:2]
+ if h < w:
+ new_h = size
+ new_w = int(round(w / h * size))
+ else:
+ new_w = size
+ new_h = int(round(h / w * size))
+ resized_image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
+ return resized_image
+
+def apply_color(img, res=512):
+ img = cv2_resize_shortest_edge(img, res)
+ h, w = img.shape[:2]
+
+ input_img_color = cv2.resize(img, (w//64, h//64), interpolation=cv2.INTER_CUBIC)
+ input_img_color = cv2.resize(input_img_color, (w, h), interpolation=cv2.INTER_NEAREST)
+ return input_img_color
+
+#Color T2I like multiples-of-64, upscale methods are fixed.
+class ColorDetector:
+ def __call__(self, input_image=None, detect_resolution=512, output_type=None, **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ input_image = HWC3(input_image)
+ detected_map = HWC3(apply_color(input_image, detect_resolution))
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
diff --git a/src/custom_controlnet_aux/densepose/__init__.py b/src/custom_controlnet_aux/densepose/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6f9448c94a027156f4cc9db965f25beaf09f1fb
--- /dev/null
+++ b/src/custom_controlnet_aux/densepose/__init__.py
@@ -0,0 +1,66 @@
+import torchvision # Fix issue Unknown builtin op: torchvision::nms
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange
+from PIL import Image
+
+from custom_controlnet_aux.util import HWC3, resize_image_with_pad, common_input_validate, custom_hf_download, DENSEPOSE_MODEL_NAME
+from .densepose import DensePoseMaskedColormapResultsVisualizer, _extract_i_from_iuvarr, densepose_chart_predictor_output_to_result_with_confidences
+
+N_PART_LABELS = 24
+
+class DenseposeDetector:
+ def __init__(self, model):
+ self.dense_pose_estimation = model
+ self.device = "cpu"
+ self.result_visualizer = DensePoseMaskedColormapResultsVisualizer(
+ alpha=1,
+ data_extractor=_extract_i_from_iuvarr,
+ segm_extractor=_extract_i_from_iuvarr,
+ val_scale = 255.0 / N_PART_LABELS
+ )
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path=DENSEPOSE_MODEL_NAME, filename="densepose_r50_fpn_dl.torchscript"):
+ torchscript_model_path = custom_hf_download(pretrained_model_or_path, filename)
+ densepose = torch.jit.load(torchscript_model_path, map_location="cpu")
+ return cls(densepose)
+
+ def to(self, device):
+ self.dense_pose_estimation.to(device)
+ self.device = device
+ return self
+
+ def __call__(self, input_image, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", cmap="viridis", **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+ H, W = input_image.shape[:2]
+
+ hint_image_canvas = np.zeros([H, W], dtype=np.uint8)
+ hint_image_canvas = np.tile(hint_image_canvas[:, :, np.newaxis], [1, 1, 3])
+
+ input_image = rearrange(torch.from_numpy(input_image).to(self.device), 'h w c -> c h w')
+
+ pred_boxes, corase_segm, fine_segm, u, v = self.dense_pose_estimation(input_image)
+
+ extractor = densepose_chart_predictor_output_to_result_with_confidences
+ densepose_results = [extractor(pred_boxes[i:i+1], corase_segm[i:i+1], fine_segm[i:i+1], u[i:i+1], v[i:i+1]) for i in range(len(pred_boxes))]
+
+ if cmap=="viridis":
+ self.result_visualizer.mask_visualizer.cmap = cv2.COLORMAP_VIRIDIS
+ hint_image = self.result_visualizer.visualize(hint_image_canvas, densepose_results)
+ hint_image = cv2.cvtColor(hint_image, cv2.COLOR_BGR2RGB)
+ hint_image[:, :, 0][hint_image[:, :, 0] == 0] = 68
+ hint_image[:, :, 1][hint_image[:, :, 1] == 0] = 1
+ hint_image[:, :, 2][hint_image[:, :, 2] == 0] = 84
+ else:
+ self.result_visualizer.mask_visualizer.cmap = cv2.COLORMAP_PARULA
+ hint_image = self.result_visualizer.visualize(hint_image_canvas, densepose_results)
+ hint_image = cv2.cvtColor(hint_image, cv2.COLOR_BGR2RGB)
+
+ detected_map = remove_pad(HWC3(hint_image))
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+ return detected_map
diff --git a/src/custom_controlnet_aux/densepose/densepose.py b/src/custom_controlnet_aux/densepose/densepose.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e43b05fcd76efd5774485ca35e715e64acefdbe
--- /dev/null
+++ b/src/custom_controlnet_aux/densepose/densepose.py
@@ -0,0 +1,347 @@
+from typing import Tuple
+import math
+import numpy as np
+from enum import IntEnum
+from typing import List, Tuple, Union
+import torch
+from torch.nn import functional as F
+import logging
+import cv2
+
+Image = np.ndarray
+Boxes = torch.Tensor
+ImageSizeType = Tuple[int, int]
+_RawBoxType = Union[List[float], Tuple[float, ...], torch.Tensor, np.ndarray]
+IntTupleBox = Tuple[int, int, int, int]
+
+class BoxMode(IntEnum):
+ """
+ Enum of different ways to represent a box.
+ """
+
+ XYXY_ABS = 0
+ """
+ (x0, y0, x1, y1) in absolute floating points coordinates.
+ The coordinates in range [0, width or height].
+ """
+ XYWH_ABS = 1
+ """
+ (x0, y0, w, h) in absolute floating points coordinates.
+ """
+ XYXY_REL = 2
+ """
+ Not yet supported!
+ (x0, y0, x1, y1) in range [0, 1]. They are relative to the size of the image.
+ """
+ XYWH_REL = 3
+ """
+ Not yet supported!
+ (x0, y0, w, h) in range [0, 1]. They are relative to the size of the image.
+ """
+ XYWHA_ABS = 4
+ """
+ (xc, yc, w, h, a) in absolute floating points coordinates.
+ (xc, yc) is the center of the rotated box, and the angle a is in degrees ccw.
+ """
+
+ @staticmethod
+ def convert(box: _RawBoxType, from_mode: "BoxMode", to_mode: "BoxMode") -> _RawBoxType:
+ """
+ Args:
+ box: can be a k-tuple, k-list or an Nxk array/tensor, where k = 4 or 5
+ from_mode, to_mode (BoxMode)
+
+ Returns:
+ The converted box of the same type.
+ """
+ if from_mode == to_mode:
+ return box
+
+ original_type = type(box)
+ is_numpy = isinstance(box, np.ndarray)
+ single_box = isinstance(box, (list, tuple))
+ if single_box:
+ assert len(box) == 4 or len(box) == 5, (
+ "BoxMode.convert takes either a k-tuple/list or an Nxk array/tensor,"
+ " where k == 4 or 5"
+ )
+ arr = torch.tensor(box)[None, :]
+ else:
+ # avoid modifying the input box
+ if is_numpy:
+ arr = torch.from_numpy(np.asarray(box)).clone()
+ else:
+ arr = box.clone()
+
+ assert to_mode not in [BoxMode.XYXY_REL, BoxMode.XYWH_REL] and from_mode not in [
+ BoxMode.XYXY_REL,
+ BoxMode.XYWH_REL,
+ ], "Relative mode not yet supported!"
+
+ if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS:
+ assert (
+ arr.shape[-1] == 5
+ ), "The last dimension of input shape must be 5 for XYWHA format"
+ original_dtype = arr.dtype
+ arr = arr.double()
+
+ w = arr[:, 2]
+ h = arr[:, 3]
+ a = arr[:, 4]
+ c = torch.abs(torch.cos(a * math.pi / 180.0))
+ s = torch.abs(torch.sin(a * math.pi / 180.0))
+ # This basically computes the horizontal bounding rectangle of the rotated box
+ new_w = c * w + s * h
+ new_h = c * h + s * w
+
+ # convert center to top-left corner
+ arr[:, 0] -= new_w / 2.0
+ arr[:, 1] -= new_h / 2.0
+ # bottom-right corner
+ arr[:, 2] = arr[:, 0] + new_w
+ arr[:, 3] = arr[:, 1] + new_h
+
+ arr = arr[:, :4].to(dtype=original_dtype)
+ elif from_mode == BoxMode.XYWH_ABS and to_mode == BoxMode.XYWHA_ABS:
+ original_dtype = arr.dtype
+ arr = arr.double()
+ arr[:, 0] += arr[:, 2] / 2.0
+ arr[:, 1] += arr[:, 3] / 2.0
+ angles = torch.zeros((arr.shape[0], 1), dtype=arr.dtype)
+ arr = torch.cat((arr, angles), axis=1).to(dtype=original_dtype)
+ else:
+ if to_mode == BoxMode.XYXY_ABS and from_mode == BoxMode.XYWH_ABS:
+ arr[:, 2] += arr[:, 0]
+ arr[:, 3] += arr[:, 1]
+ elif from_mode == BoxMode.XYXY_ABS and to_mode == BoxMode.XYWH_ABS:
+ arr[:, 2] -= arr[:, 0]
+ arr[:, 3] -= arr[:, 1]
+ else:
+ raise NotImplementedError(
+ "Conversion from BoxMode {} to {} is not supported yet".format(
+ from_mode, to_mode
+ )
+ )
+
+ if single_box:
+ return original_type(arr.flatten().tolist())
+ if is_numpy:
+ return arr.numpy()
+ else:
+ return arr
+
+class MatrixVisualizer:
+ """
+ Base visualizer for matrix data
+ """
+
+ def __init__(
+ self,
+ inplace=True,
+ cmap=cv2.COLORMAP_PARULA,
+ val_scale=1.0,
+ alpha=0.7,
+ interp_method_matrix=cv2.INTER_LINEAR,
+ interp_method_mask=cv2.INTER_NEAREST,
+ ):
+ self.inplace = inplace
+ self.cmap = cmap
+ self.val_scale = val_scale
+ self.alpha = alpha
+ self.interp_method_matrix = interp_method_matrix
+ self.interp_method_mask = interp_method_mask
+
+ def visualize(self, image_bgr, mask, matrix, bbox_xywh):
+ self._check_image(image_bgr)
+ self._check_mask_matrix(mask, matrix)
+ if self.inplace:
+ image_target_bgr = image_bgr
+ else:
+ image_target_bgr = image_bgr * 0
+ x, y, w, h = [int(v) for v in bbox_xywh]
+ if w <= 0 or h <= 0:
+ return image_bgr
+ mask, matrix = self._resize(mask, matrix, w, h)
+ mask_bg = np.tile((mask == 0)[:, :, np.newaxis], [1, 1, 3])
+ matrix_scaled = matrix.astype(np.float32) * self.val_scale
+ _EPSILON = 1e-6
+ if np.any(matrix_scaled > 255 + _EPSILON):
+ logger = logging.getLogger(__name__)
+ logger.warning(
+ f"Matrix has values > {255 + _EPSILON} after " f"scaling, clipping to [0..255]"
+ )
+ matrix_scaled_8u = matrix_scaled.clip(0, 255).astype(np.uint8)
+ matrix_vis = cv2.applyColorMap(matrix_scaled_8u, self.cmap)
+ matrix_vis[mask_bg] = image_target_bgr[y : y + h, x : x + w, :][mask_bg]
+ image_target_bgr[y : y + h, x : x + w, :] = (
+ image_target_bgr[y : y + h, x : x + w, :] * (1.0 - self.alpha) + matrix_vis * self.alpha
+ )
+ return image_target_bgr.astype(np.uint8)
+
+ def _resize(self, mask, matrix, w, h):
+ if (w != mask.shape[1]) or (h != mask.shape[0]):
+ mask = cv2.resize(mask, (w, h), self.interp_method_mask)
+ if (w != matrix.shape[1]) or (h != matrix.shape[0]):
+ matrix = cv2.resize(matrix, (w, h), self.interp_method_matrix)
+ return mask, matrix
+
+ def _check_image(self, image_rgb):
+ assert len(image_rgb.shape) == 3
+ assert image_rgb.shape[2] == 3
+ assert image_rgb.dtype == np.uint8
+
+ def _check_mask_matrix(self, mask, matrix):
+ assert len(matrix.shape) == 2
+ assert len(mask.shape) == 2
+ assert mask.dtype == np.uint8
+
+class DensePoseResultsVisualizer:
+ def visualize(
+ self,
+ image_bgr: Image,
+ results,
+ ) -> Image:
+ context = self.create_visualization_context(image_bgr)
+ for i, result in enumerate(results):
+ boxes_xywh, labels, uv = result
+ iuv_array = torch.cat(
+ (labels[None].type(torch.float32), uv * 255.0)
+ ).type(torch.uint8)
+ self.visualize_iuv_arr(context, iuv_array.cpu().numpy(), boxes_xywh)
+ image_bgr = self.context_to_image_bgr(context)
+ return image_bgr
+
+ def create_visualization_context(self, image_bgr: Image):
+ return image_bgr
+
+ def visualize_iuv_arr(self, context, iuv_arr: np.ndarray, bbox_xywh) -> None:
+ pass
+
+ def context_to_image_bgr(self, context):
+ return context
+
+ def get_image_bgr_from_context(self, context):
+ return context
+
+class DensePoseMaskedColormapResultsVisualizer(DensePoseResultsVisualizer):
+ def __init__(
+ self,
+ data_extractor,
+ segm_extractor,
+ inplace=True,
+ cmap=cv2.COLORMAP_PARULA,
+ alpha=0.7,
+ val_scale=1.0,
+ **kwargs,
+ ):
+ self.mask_visualizer = MatrixVisualizer(
+ inplace=inplace, cmap=cmap, val_scale=val_scale, alpha=alpha
+ )
+ self.data_extractor = data_extractor
+ self.segm_extractor = segm_extractor
+
+ def context_to_image_bgr(self, context):
+ return context
+
+ def visualize_iuv_arr(self, context, iuv_arr: np.ndarray, bbox_xywh) -> None:
+ image_bgr = self.get_image_bgr_from_context(context)
+ matrix = self.data_extractor(iuv_arr)
+ segm = self.segm_extractor(iuv_arr)
+ mask = np.zeros(matrix.shape, dtype=np.uint8)
+ mask[segm > 0] = 1
+ image_bgr = self.mask_visualizer.visualize(image_bgr, mask, matrix, bbox_xywh)
+
+
+def _extract_i_from_iuvarr(iuv_arr):
+ return iuv_arr[0, :, :]
+
+
+def _extract_u_from_iuvarr(iuv_arr):
+ return iuv_arr[1, :, :]
+
+
+def _extract_v_from_iuvarr(iuv_arr):
+ return iuv_arr[2, :, :]
+
+def make_int_box(box: torch.Tensor) -> IntTupleBox:
+ int_box = [0, 0, 0, 0]
+ int_box[0], int_box[1], int_box[2], int_box[3] = tuple(box.long().tolist())
+ return int_box[0], int_box[1], int_box[2], int_box[3]
+
+def densepose_chart_predictor_output_to_result_with_confidences(
+ boxes: Boxes,
+ coarse_segm,
+ fine_segm,
+ u, v
+
+):
+ boxes_xyxy_abs = boxes.clone()
+ boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
+ box_xywh = make_int_box(boxes_xywh_abs[0])
+
+ labels = resample_fine_and_coarse_segm_tensors_to_bbox(fine_segm, coarse_segm, box_xywh).squeeze(0)
+ uv = resample_uv_tensors_to_bbox(u, v, labels, box_xywh)
+ confidences = []
+ return box_xywh, labels, uv
+
+def resample_fine_and_coarse_segm_tensors_to_bbox(
+ fine_segm: torch.Tensor, coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox
+):
+ """
+ Resample fine and coarse segmentation tensors to the given
+ bounding box and derive labels for each pixel of the bounding box
+
+ Args:
+ fine_segm: float tensor of shape [1, C, Hout, Wout]
+ coarse_segm: float tensor of shape [1, K, Hout, Wout]
+ box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
+ corner coordinates, width (W) and height (H)
+ Return:
+ Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
+ """
+ x, y, w, h = box_xywh_abs
+ w = max(int(w), 1)
+ h = max(int(h), 1)
+ # coarse segmentation
+ coarse_segm_bbox = F.interpolate(
+ coarse_segm,
+ (h, w),
+ mode="bilinear",
+ align_corners=False,
+ ).argmax(dim=1)
+ # combined coarse and fine segmentation
+ labels = (
+ F.interpolate(fine_segm, (h, w), mode="bilinear", align_corners=False).argmax(dim=1)
+ * (coarse_segm_bbox > 0).long()
+ )
+ return labels
+
+def resample_uv_tensors_to_bbox(
+ u: torch.Tensor,
+ v: torch.Tensor,
+ labels: torch.Tensor,
+ box_xywh_abs: IntTupleBox,
+) -> torch.Tensor:
+ """
+ Resamples U and V coordinate estimates for the given bounding box
+
+ Args:
+ u (tensor [1, C, H, W] of float): U coordinates
+ v (tensor [1, C, H, W] of float): V coordinates
+ labels (tensor [H, W] of long): labels obtained by resampling segmentation
+ outputs for the given bounding box
+ box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
+ Return:
+ Resampled U and V coordinates - a tensor [2, H, W] of float
+ """
+ x, y, w, h = box_xywh_abs
+ w = max(int(w), 1)
+ h = max(int(h), 1)
+ u_bbox = F.interpolate(u, (h, w), mode="bilinear", align_corners=False)
+ v_bbox = F.interpolate(v, (h, w), mode="bilinear", align_corners=False)
+ uv = torch.zeros([2, h, w], dtype=torch.float32, device=u.device)
+ for part_id in range(1, u_bbox.size(1)):
+ uv[0][labels == part_id] = u_bbox[0, part_id][labels == part_id]
+ uv[1][labels == part_id] = v_bbox[0, part_id][labels == part_id]
+ return uv
+
diff --git a/src/custom_controlnet_aux/depth_anything/__init__.py b/src/custom_controlnet_aux/depth_anything/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5240bbdddf7ba3ac2cbfe8a446a5fa2a0106d079
--- /dev/null
+++ b/src/custom_controlnet_aux/depth_anything/__init__.py
@@ -0,0 +1 @@
+from .transformers import DepthAnythingDetector
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/depth_anything/transformers.py b/src/custom_controlnet_aux/depth_anything/transformers.py
new file mode 100644
index 0000000000000000000000000000000000000000..66a0ff6358ce57678ce22a2da703cc2f3e171afe
--- /dev/null
+++ b/src/custom_controlnet_aux/depth_anything/transformers.py
@@ -0,0 +1,75 @@
+"""
+Modern DepthAnything implementation using HuggingFace transformers.
+Replaces legacy torch.hub.load DINOv2 backbone with transformers pipeline.
+"""
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import pipeline
+
+from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad
+
+class DepthAnythingDetector:
+ """DepthAnything depth estimation using HuggingFace transformers."""
+
+ def __init__(self, model_name="LiheYoung/depth-anything-large-hf"):
+ """Initialize DepthAnything with specified model."""
+ self.pipe = pipeline(task="depth-estimation", model=model_name)
+ self.device = "cpu"
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path=None, filename="depth_anything_vitl14.pth"):
+ """Create DepthAnything from pretrained model, mapping legacy names to HuggingFace models."""
+
+ # Map legacy checkpoint names to modern HuggingFace models
+ model_mapping = {
+ "depth_anything_vitl14.pth": "LiheYoung/depth-anything-large-hf",
+ "depth_anything_vitb14.pth": "LiheYoung/depth-anything-base-hf",
+ "depth_anything_vits14.pth": "LiheYoung/depth-anything-small-hf"
+ }
+
+ model_name = model_mapping.get(filename, "LiheYoung/depth-anything-large-hf")
+ return cls(model_name=model_name)
+
+ def to(self, device):
+ """Move model to specified device."""
+ self.pipe.model = self.pipe.model.to(device)
+ self.device = device
+ return self
+
+ def __call__(self, input_image, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
+ """Perform depth estimation on input image."""
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+
+ if isinstance(input_image, np.ndarray):
+ pil_image = Image.fromarray(input_image)
+ else:
+ pil_image = input_image
+
+ with torch.no_grad():
+ result = self.pipe(pil_image)
+ depth = result["depth"]
+
+ if isinstance(depth, Image.Image):
+ depth_array = np.array(depth, dtype=np.float32)
+ else:
+ depth_array = np.array(depth)
+
+ # Normalize depth values to 0-255 range
+ depth_min = depth_array.min()
+ depth_max = depth_array.max()
+ if depth_max > depth_min:
+ depth_array = (depth_array - depth_min) / (depth_max - depth_min) * 255.0
+ else:
+ depth_array = np.zeros_like(depth_array)
+
+ depth_image = depth_array.astype(np.uint8)
+
+ detected_map = remove_pad(HWC3(depth_image))
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/depth_anything_v2/__init__.py b/src/custom_controlnet_aux/depth_anything_v2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..924e7bc7f6ae44ecaa1a365a4983bc72bbce71a4
--- /dev/null
+++ b/src/custom_controlnet_aux/depth_anything_v2/__init__.py
@@ -0,0 +1,56 @@
+import numpy as np
+import torch
+from einops import repeat
+from PIL import Image
+from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, DEPTH_ANYTHING_V2_MODEL_NAME_DICT
+from custom_controlnet_aux.depth_anything_v2.dpt import DepthAnythingV2
+import cv2
+import torch.nn.functional as F
+
+
+# https://github.com/DepthAnything/Depth-Anything-V2/blob/main/app.py
+model_configs = {
+ 'depth_anything_v2_vits.pth': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
+ 'depth_anything_v2_vitb.pth': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
+ 'depth_anything_v2_vitl.pth': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
+ 'depth_anything_v2_vitg.pth': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]},
+ 'depth_anything_v2_metric_vkitti_vitl.pth': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
+ 'depth_anything_v2_metric_hypersim_vitl.pth': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
+}
+
+class DepthAnythingV2Detector:
+ def __init__(self, model, filename):
+ self.model = model
+ self.device = "cpu"
+ self.filename = filename
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path=None, filename="depth_anything_v2_vits.pth"):
+ if pretrained_model_or_path is None:
+ pretrained_model_or_path = DEPTH_ANYTHING_V2_MODEL_NAME_DICT[filename]
+ model_path = custom_hf_download(pretrained_model_or_path, filename)
+ model = DepthAnythingV2(**model_configs[filename])
+ model.load_state_dict(torch.load(model_path, map_location="cpu"))
+ model = model.eval()
+ return cls(model, filename)
+
+ def to(self, device):
+ self.model.to(device)
+ self.device = device
+ return self
+
+ def __call__(self, input_image, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", max_depth=20.0, **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+
+ depth = self.model.infer_image(cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR), input_size=518, max_depth=max_depth)
+ depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
+ depth = depth.astype(np.uint8)
+ if 'metric' in self.filename:
+ depth = 255 - depth
+
+ detected_map = repeat(depth, "h w -> h w 3")
+ detected_map, remove_pad = resize_image_with_pad(detected_map, detect_resolution, upscale_method)
+ detected_map = remove_pad(detected_map)
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/depth_anything_v2/dinov2.py b/src/custom_controlnet_aux/depth_anything_v2/dinov2.py
new file mode 100644
index 0000000000000000000000000000000000000000..78e99975d317fbaac81d0bab5401ce91ccfff070
--- /dev/null
+++ b/src/custom_controlnet_aux/depth_anything_v2/dinov2.py
@@ -0,0 +1,415 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+
+from custom_controlnet_aux.depth_anything_v2.dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+
+logger = logging.getLogger("dinov2")
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x):
+ for b in self:
+ x = b(x)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=None, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ ffn_layer="mlp",
+ block_chunks=1,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1,
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.num_register_tokens = num_register_tokens
+ self.interpolate_antialias = interpolate_antialias
+ self.interpolate_offset = interpolate_offset
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ assert num_register_tokens >= 0
+ self.register_tokens = (
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
+ )
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ if self.register_tokens is not None:
+ nn.init.normal_(self.register_tokens, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
+ w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
+ # w0, h0 = w0 + 0.1, h0 + 0.1
+
+ sqrt_N = math.sqrt(N)
+ sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
+ scale_factor=(sx, sy),
+ # (int(w0), int(h0)), # to solve the upsampling shape issue
+ mode="bicubic",
+ antialias=self.interpolate_antialias
+ )
+
+ assert int(w0) == patch_pos_embed.shape[-2]
+ assert int(h0) == patch_pos_embed.shape[-1]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ if self.register_tokens is not None:
+ x = torch.cat(
+ (
+ x[:, :1],
+ self.register_tokens.expand(x.shape[0], -1, -1),
+ x[:, 1:],
+ ),
+ dim=1,
+ )
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+ for blk in self.blocks:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x_norm = self.norm(x)
+ return {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=False, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ if is_training:
+ return ret
+ else:
+ return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def DINOv2(model_name):
+ model_zoo = {
+ "vits": vit_small,
+ "vitb": vit_base,
+ "vitl": vit_large,
+ "vitg": vit_giant2
+ }
+
+ return model_zoo[model_name](
+ img_size=518,
+ patch_size=14,
+ init_values=1.0,
+ ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
+ block_chunks=0,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1
+ )
diff --git a/src/custom_controlnet_aux/depth_anything_v2/dinov2_layers/__init__.py b/src/custom_controlnet_aux/depth_anything_v2/dinov2_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8120f4bc83066cb3f825ce32daa3b437f88486f1
--- /dev/null
+++ b/src/custom_controlnet_aux/depth_anything_v2/dinov2_layers/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+from .block import NestedTensorBlock
+from .attention import MemEffAttention
diff --git a/src/custom_controlnet_aux/depth_anything_v2/dinov2_layers/attention.py b/src/custom_controlnet_aux/depth_anything_v2/dinov2_layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..815a2bf53dbec496f6a184ed7d03bcecb7124262
--- /dev/null
+++ b/src/custom_controlnet_aux/depth_anything_v2/dinov2_layers/attention.py
@@ -0,0 +1,83 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+
+from torch import Tensor
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+ from xformers.ops import memory_efficient_attention, unbind, fmha
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ attn = q @ k.transpose(-2, -1)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/depth_anything_v2/dinov2_layers/block.py b/src/custom_controlnet_aux/depth_anything_v2/dinov2_layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..25488f57cc0ad3c692f86b62555f6668e2a66db1
--- /dev/null
+++ b/src/custom_controlnet_aux/depth_anything_v2/dinov2_layers/block.py
@@ -0,0 +1,252 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+from typing import Callable, List, Any, Tuple, Dict
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention, MemEffAttention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+ from xformers.ops import fmha
+ from xformers.ops import scaled_index_add, index_select_cat
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor) -> Tensor:
+ def attn_residual_func(x: Tensor) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x)))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0,
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
diff --git a/src/custom_controlnet_aux/depth_anything_v2/dinov2_layers/drop_path.py b/src/custom_controlnet_aux/depth_anything_v2/dinov2_layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..af05625984dd14682cc96a63bf0c97bab1f123b1
--- /dev/null
+++ b/src/custom_controlnet_aux/depth_anything_v2/dinov2_layers/drop_path.py
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
diff --git a/src/custom_controlnet_aux/depth_anything_v2/dinov2_layers/layer_scale.py b/src/custom_controlnet_aux/depth_anything_v2/dinov2_layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca5daa52bd81d3581adeb2198ea5b7dba2a3aea1
--- /dev/null
+++ b/src/custom_controlnet_aux/depth_anything_v2/dinov2_layers/layer_scale.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
diff --git a/src/custom_controlnet_aux/depth_anything_v2/dinov2_layers/mlp.py b/src/custom_controlnet_aux/depth_anything_v2/dinov2_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e4b315f972f9a9f54aef1e4ef4e81b52976f018
--- /dev/null
+++ b/src/custom_controlnet_aux/depth_anything_v2/dinov2_layers/mlp.py
@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
diff --git a/src/custom_controlnet_aux/depth_anything_v2/dinov2_layers/patch_embed.py b/src/custom_controlnet_aux/depth_anything_v2/dinov2_layers/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..574abe41175568d700a389b8b96d1ba554914779
--- /dev/null
+++ b/src/custom_controlnet_aux/depth_anything_v2/dinov2_layers/patch_embed.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
diff --git a/src/custom_controlnet_aux/depth_anything_v2/dinov2_layers/swiglu_ffn.py b/src/custom_controlnet_aux/depth_anything_v2/dinov2_layers/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3324b266fb0a50ccf8c3a0ede2ae10ac4dfa03e
--- /dev/null
+++ b/src/custom_controlnet_aux/depth_anything_v2/dinov2_layers/swiglu_ffn.py
@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+try:
+ from xformers.ops import SwiGLU
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ SwiGLU = SwiGLUFFN
+ XFORMERS_AVAILABLE = False
+
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
diff --git a/src/custom_controlnet_aux/depth_anything_v2/dpt.py b/src/custom_controlnet_aux/depth_anything_v2/dpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..630ae6936895db1d0d8806b30a0815fbecf83443
--- /dev/null
+++ b/src/custom_controlnet_aux/depth_anything_v2/dpt.py
@@ -0,0 +1,220 @@
+import cv2
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.transforms import Compose
+
+from custom_controlnet_aux.depth_anything_v2.dinov2 import DINOv2
+from custom_controlnet_aux.depth_anything_v2.util.blocks import FeatureFusionBlock, _make_scratch
+from custom_controlnet_aux.depth_anything_v2.util.transform import Resize, NormalizeImage, PrepareForNet
+
+
+def _make_fusion_block(features, use_bn, size=None):
+ return FeatureFusionBlock(
+ features,
+ nn.ReLU(False),
+ deconv=False,
+ bn=use_bn,
+ expand=False,
+ align_corners=True,
+ size=size,
+ )
+
+
+class ConvBlock(nn.Module):
+ def __init__(self, in_feature, out_feature):
+ super().__init__()
+
+ self.conv_block = nn.Sequential(
+ nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2d(out_feature),
+ nn.ReLU(True)
+ )
+
+ def forward(self, x):
+ return self.conv_block(x)
+
+
+class DPTHead(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ features=256,
+ use_bn=False,
+ out_channels=[256, 512, 1024, 1024],
+ use_clstoken=False
+ ):
+ super(DPTHead, self).__init__()
+
+ self.use_clstoken = use_clstoken
+
+ self.projects = nn.ModuleList([
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channel,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ) for out_channel in out_channels
+ ])
+
+ self.resize_layers = nn.ModuleList([
+ nn.ConvTranspose2d(
+ in_channels=out_channels[0],
+ out_channels=out_channels[0],
+ kernel_size=4,
+ stride=4,
+ padding=0),
+ nn.ConvTranspose2d(
+ in_channels=out_channels[1],
+ out_channels=out_channels[1],
+ kernel_size=2,
+ stride=2,
+ padding=0),
+ nn.Identity(),
+ nn.Conv2d(
+ in_channels=out_channels[3],
+ out_channels=out_channels[3],
+ kernel_size=3,
+ stride=2,
+ padding=1)
+ ])
+
+ if use_clstoken:
+ self.readout_projects = nn.ModuleList()
+ for _ in range(len(self.projects)):
+ self.readout_projects.append(
+ nn.Sequential(
+ nn.Linear(2 * in_channels, in_channels),
+ nn.GELU()))
+
+ self.scratch = _make_scratch(
+ out_channels,
+ features,
+ groups=1,
+ expand=False,
+ )
+
+ self.scratch.stem_transpose = None
+
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+ head_features_1 = features
+ head_features_2 = 32
+
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
+ self.scratch.output_conv2 = nn.Sequential(
+ nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True),
+ nn.Identity(),
+ )
+
+ def forward(self, out_features, patch_h, patch_w):
+ out = []
+ for i, x in enumerate(out_features):
+ if self.use_clstoken:
+ x, cls_token = x[0], x[1]
+ readout = cls_token.unsqueeze(1).expand_as(x)
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
+ else:
+ x = x[0]
+
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
+
+ x = self.projects[i](x)
+ x = self.resize_layers[i](x)
+
+ out.append(x)
+
+ layer_1, layer_2, layer_3, layer_4 = out
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv1(path_1)
+ out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
+ out = self.scratch.output_conv2(out)
+
+ return out
+
+
+class DepthAnythingV2(nn.Module):
+ def __init__(
+ self,
+ encoder='vitl',
+ features=256,
+ out_channels=[256, 512, 1024, 1024],
+ use_bn=False,
+ use_clstoken=False
+ ):
+ super(DepthAnythingV2, self).__init__()
+
+ self.intermediate_layer_idx = {
+ 'vits': [2, 5, 8, 11],
+ 'vitb': [2, 5, 8, 11],
+ 'vitl': [4, 11, 17, 23],
+ 'vitg': [9, 19, 29, 39]
+ }
+
+ self.encoder = encoder
+ self.pretrained = DINOv2(model_name=encoder)
+
+ self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
+
+ def forward(self, x, max_depth):
+ patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14
+
+ features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True)
+
+ depth = self.depth_head(features, patch_h, patch_w) * max_depth
+
+ return depth.squeeze(1)
+
+ @torch.no_grad()
+ def infer_image(self, raw_image, input_size=518, max_depth=20.0):
+ image, (h, w) = self.image2tensor(raw_image, input_size)
+
+ depth = self.forward(image, max_depth)
+
+ depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0]
+
+ return depth.cpu().numpy()
+
+ def image2tensor(self, raw_image, input_size=518):
+ transform = Compose([
+ Resize(
+ width=input_size,
+ height=input_size,
+ resize_target=False,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=14,
+ resize_method='lower_bound',
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ PrepareForNet(),
+ ])
+
+ h, w = raw_image.shape[:2]
+
+ image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0
+
+ image = transform({'image': image})['image']
+ image = torch.from_numpy(image).unsqueeze(0)
+
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
+ image = image.to(DEVICE)
+
+ return image, (h, w)
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/depth_anything_v2/util/blocks.py b/src/custom_controlnet_aux/depth_anything_v2/util/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..382ea183a40264056142afffc201c992a2b01d37
--- /dev/null
+++ b/src/custom_controlnet_aux/depth_anything_v2/util/blocks.py
@@ -0,0 +1,148 @@
+import torch.nn as nn
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+ scratch = nn.Module()
+
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape
+
+ if expand:
+ out_shape1 = out_shape
+ out_shape2 = out_shape * 2
+ out_shape3 = out_shape * 4
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape * 8
+
+ scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
+ scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
+ scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
+ if len(in_shape) >= 4:
+ scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
+
+ return scratch
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features, activation, bn):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+
+ self.groups=1
+
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+
+ if self.bn == True:
+ self.bn1 = nn.BatchNorm2d(features)
+ self.bn2 = nn.BatchNorm2d(features)
+
+ self.activation = activation
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.bn == True:
+ out = self.bn1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.bn == True:
+ out = self.bn2(out)
+
+ if self.groups > 1:
+ out = self.conv_merge(out)
+
+ return self.skip_add.add(out, x)
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(
+ self,
+ features,
+ activation,
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ size=None
+ ):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+
+ self.groups=1
+
+ self.expand = expand
+ out_features = features
+ if self.expand == True:
+ out_features = features // 2
+
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ self.size=size
+
+ def forward(self, *xs, size=None):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+
+ output = self.resConfUnit2(output)
+
+ if (size is None) and (self.size is None):
+ modifier = {"scale_factor": 2}
+ elif size is None:
+ modifier = {"size": self.size}
+ else:
+ modifier = {"size": size}
+
+ output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
+
+ output = self.out_conv(output)
+
+ return output
diff --git a/src/custom_controlnet_aux/depth_anything_v2/util/transform.py b/src/custom_controlnet_aux/depth_anything_v2/util/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..b14aacd44ea086b01725a9ca68bb49eadcf37d73
--- /dev/null
+++ b/src/custom_controlnet_aux/depth_anything_v2/util/transform.py
@@ -0,0 +1,158 @@
+import numpy as np
+import cv2
+
+
+class Resize(object):
+ """Resize sample to given size (width, height).
+ """
+
+ def __init__(
+ self,
+ width,
+ height,
+ resize_target=True,
+ keep_aspect_ratio=False,
+ ensure_multiple_of=1,
+ resize_method="lower_bound",
+ image_interpolation_method=cv2.INTER_AREA,
+ ):
+ """Init.
+
+ Args:
+ width (int): desired output width
+ height (int): desired output height
+ resize_target (bool, optional):
+ True: Resize the full sample (image, mask, target).
+ False: Resize image only.
+ Defaults to True.
+ keep_aspect_ratio (bool, optional):
+ True: Keep the aspect ratio of the input sample.
+ Output sample might not have the given width and height, and
+ resize behaviour depends on the parameter 'resize_method'.
+ Defaults to False.
+ ensure_multiple_of (int, optional):
+ Output width and height is constrained to be multiple of this parameter.
+ Defaults to 1.
+ resize_method (str, optional):
+ "lower_bound": Output will be at least as large as the given size.
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
+ Defaults to "lower_bound".
+ """
+ self.__width = width
+ self.__height = height
+
+ self.__resize_target = resize_target
+ self.__keep_aspect_ratio = keep_aspect_ratio
+ self.__multiple_of = ensure_multiple_of
+ self.__resize_method = resize_method
+ self.__image_interpolation_method = image_interpolation_method
+
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if max_val is not None and y > max_val:
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if y < min_val:
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ return y
+
+ def get_size(self, width, height):
+ # determine new height and width
+ scale_height = self.__height / height
+ scale_width = self.__width / width
+
+ if self.__keep_aspect_ratio:
+ if self.__resize_method == "lower_bound":
+ # scale such that output size is lower bound
+ if scale_width > scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "upper_bound":
+ # scale such that output size is upper bound
+ if scale_width < scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "minimal":
+ # scale as least as possbile
+ if abs(1 - scale_width) < abs(1 - scale_height):
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+ if self.__resize_method == "lower_bound":
+ new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
+ new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
+ elif self.__resize_method == "upper_bound":
+ new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
+ new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
+ elif self.__resize_method == "minimal":
+ new_height = self.constrain_to_multiple_of(scale_height * height)
+ new_width = self.constrain_to_multiple_of(scale_width * width)
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+ return (new_width, new_height)
+
+ def __call__(self, sample):
+ width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
+
+ # resize sample
+ sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)
+
+ if self.__resize_target:
+ if "depth" in sample:
+ sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
+
+ if "mask" in sample:
+ sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)
+
+ return sample
+
+
+class NormalizeImage(object):
+ """Normlize image by given mean and std.
+ """
+
+ def __init__(self, mean, std):
+ self.__mean = mean
+ self.__std = std
+
+ def __call__(self, sample):
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
+
+ return sample
+
+
+class PrepareForNet(object):
+ """Prepare sample for usage as network input.
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, sample):
+ image = np.transpose(sample["image"], (2, 0, 1))
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+
+ if "depth" in sample:
+ depth = sample["depth"].astype(np.float32)
+ sample["depth"] = np.ascontiguousarray(depth)
+
+ if "mask" in sample:
+ sample["mask"] = sample["mask"].astype(np.float32)
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
+
+ return sample
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/diffusion_edge/__init__.py b/src/custom_controlnet_aux/diffusion_edge/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2d6010f8d4b16327bb78a3baf2dee27581cf1ef
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/__init__.py
@@ -0,0 +1,40 @@
+from custom_controlnet_aux.diffusion_edge.model import DiffusionEdge, prepare_args
+import numpy as np
+import torch
+from einops import rearrange
+from PIL import Image
+from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, DIFFUSION_EDGE_MODEL_NAME
+
+class DiffusionEdgeDetector:
+ def __init__(self, model):
+ self.model = model
+ self.device = "cpu"
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path=DIFFUSION_EDGE_MODEL_NAME, filename="diffusion_edge_indoor.pt"):
+ model_path = custom_hf_download(pretrained_model_or_path, filename)
+ model = DiffusionEdge(prepare_args(model_path))
+ return cls(model)
+
+ def to(self, device):
+ self.model.to(device)
+ self.device = device
+ return self
+
+ def __call__(self, input_image, detect_resolution=512, patch_batch_size=8, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+
+ with torch.no_grad():
+ input_image = rearrange(torch.from_numpy(input_image), "h w c -> 1 c h w")
+ input_image = input_image.float() / 255.
+ line = self.model(input_image, patch_batch_size)
+ line = rearrange(line, "1 c h w -> h w c")
+
+ detected_map = line.cpu().numpy().__mul__(255.).astype(np.uint8)
+ detected_map = remove_pad(HWC3(detected_map))
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/diffusion_edge/default.yaml b/src/custom_controlnet_aux/diffusion_edge/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0f13fbd63921bee108c5a015a2aae2a44122fb69
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/default.yaml
@@ -0,0 +1,74 @@
+model:
+ model_type: const_sde
+ model_name: cond_unet
+ image_size: [320, 320]
+ input_keys: ['image', 'cond']
+ ckpt_path:
+ ignore_keys: [ ]
+ only_model: False
+ timesteps: 1000
+ train_sample: -1
+ sampling_timesteps: 1
+ loss_type: l2
+ objective: pred_noise
+ start_dist: normal
+ perceptual_weight: 0
+ scale_factor: 0.3
+ scale_by_std: True
+ default_scale: True
+ scale_by_softsign: False
+ eps: !!float 1e-4
+ weighting_loss: False
+ first_stage:
+ embed_dim: 3
+ lossconfig:
+ disc_start: 50001
+ kl_weight: 0.000001
+ disc_weight: 0.5
+ disc_in_channels: 1
+ ddconfig:
+ double_z: True
+ z_channels: 3
+ resolution: [ 320, 320 ]
+ in_channels: 1
+ out_ch: 1
+ ch: 128
+ ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
+ num_res_blocks: 2
+ attn_resolutions: [ ]
+ dropout: 0.0
+ ckpt_path:
+ unet:
+ dim: 128
+ cond_net: swin
+ without_pretrain: False
+ channels: 3
+ out_mul: 1
+ dim_mults: [ 1, 2, 4, 4, ] # num_down = len(dim_mults)
+ cond_in_dim: 3
+ cond_dim: 128
+ cond_dim_mults: [ 2, 4 ] # num_down = len(cond_dim_mults)
+ # window_sizes1: [ [4, 4], [2, 2], [1, 1], [1, 1] ]
+ # window_sizes2: [ [4, 4], [2, 2], [1, 1], [1, 1] ]
+ window_sizes1: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ]
+ window_sizes2: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ]
+ fourier_scale: 16
+ cond_pe: False
+ num_pos_feats: 128
+ cond_feature_size: [ 80, 80 ]
+
+data:
+ name: edge
+ img_folder: '/data/yeyunfan/edge_detection_datasets/datasets/BSDS_test'
+ augment_horizontal_flip: True
+ batch_size: 8
+ num_workers: 4
+
+sampler:
+ sample_type: "slide"
+ stride: [240, 240]
+ batch_size: 1
+ sample_num: 300
+ use_ema: True
+ save_folder:
+ ckpt_path:
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/__init__.py b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dc97bb8da2613d82f5756f5bc9476e4fda1fd95
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/__init__.py
@@ -0,0 +1 @@
+# from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.denoising_diffusion_pytorch import GaussianDiffusion, Unet, Trainer
diff --git a/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/data.py b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..859a62555ed552856a191d95c5e4c47fa7e3002b
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/data.py
@@ -0,0 +1,598 @@
+import torch
+import torchvision.transforms as T
+import torch.utils.data as data
+import torch.nn as nn
+from pathlib import Path
+from functools import partial
+from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.utils import exists, convert_image_to_fn, normalize_to_neg_one_to_one
+from PIL import Image, ImageDraw
+import torch.nn.functional as F
+import math
+import torchvision.transforms.functional as F2
+import torchvision.transforms as transforms
+import torchvision.datasets as datasets
+from typing import Any, Callable, Optional, Tuple
+import os
+import pickle
+import numpy as np
+import copy
+import albumentations
+from torchvision.transforms.functional import InterpolationMode
+
+def get_imgs_list(imgs_dir):
+ imgs_list = os.listdir(imgs_dir)
+ imgs_list.sort()
+ return [os.path.join(imgs_dir, f) for f in imgs_list if f.endswith('.jpg') or f.endswith('.JPG')or f.endswith('.png') or f.endswith('.pgm') or f.endswith('.ppm')]
+
+
+def fit_img_postfix(img_path):
+ if not os.path.exists(img_path) and img_path.endswith(".jpg"):
+ img_path = img_path[:-4] + ".png"
+ if not os.path.exists(img_path) and img_path.endswith(".png"):
+ img_path = img_path[:-4] + ".jpg"
+ return img_path
+
+
+class AdaptEdgeDataset(data.Dataset):
+ def __init__(
+ self,
+ data_root,
+ # mask_folder,
+ image_size,
+ exts = ['png', 'jpg'],
+ augment_horizontal_flip = False,
+ convert_image_to = None,
+ normalize_to_neg_one_to_one=True,
+ split='train',
+ # inter_type='bicubic',
+ # down=4,
+ threshold=0.3, use_uncertainty=False
+ ):
+ super().__init__()
+ # self.img_folder = Path(img_folder)
+ # self.edge_folder = Path(os.path.join(data_root, f'gt_imgs'))
+ # self.img_folder = Path(os.path.join(data_root, f'imgs'))
+ # self.edge_folder = Path(os.path.join(data_root, "edge", "aug"))
+ # self.img_folder = Path(os.path.join(data_root, "image", "aug"))
+ self.data_root = data_root
+ self.image_size = image_size
+
+ # self.edge_paths = [p for ext in exts for p in self.edge_folder.rglob(f'*.{ext}')]
+ # self.img_paths = [(self.img_folder / item.parent.name / f'{item.stem}.jpg') for item in self.edge_paths]
+ # self.img_paths = [(self.img_folder / f'{item.stem}.jpg') for item in self.edge_paths]
+
+ self.threshold = threshold * 256
+ self.use_uncertainty = use_uncertainty
+ self.normalize_to_neg_one_to_one = normalize_to_neg_one_to_one
+
+ maybe_convert_fn = partial(convert_image_to_fn, convert_image_to) if exists(convert_image_to) else Identity()
+
+ # self.normalize_to_neg_one_to_one = normalize_to_neg_one_to_one
+ # self.random_crop = RandomCrop(size=image_size)
+ # self.transform = Compose([
+ # # Lambda(maybe_convert_fn),
+ # # Resize(image_size, interpolation=3, interpolation2=0),
+ # Resize(image_size, interpolation=InterpolationMode.BILINEAR, interpolation2=InterpolationMode.NEAREST),
+ # RandomHorizontalFlip() if augment_horizontal_flip else Identity(),
+ # # RandomCrop(image_size),
+ # ToTensor()
+ # ])
+ self.data_list = self.build_list()
+
+ self.transform = transforms.Compose([
+ # Resize(self.image_size, interpolation=InterpolationMode.BILINEAR, interpolation2=InterpolationMode.NEAREST),
+ transforms.ToTensor()])
+
+ def __len__(self):
+ return len(self.data_list)
+
+
+ def read_img(self, image_path):
+ with open(image_path, 'rb') as f:
+ img = Image.open(f)
+ img = img.convert('RGB')
+
+ raw_width, raw_height = img.size
+ # width = int(raw_width / 32) * 32
+ # height = int(raw_height / 32) * 32
+ # img = img.resize((width, height), Image.Resampling.BILINEAR)
+ # # print("img.size:", img.size)
+ # img = self.transform(img)
+
+ return img, (raw_width, raw_height)
+
+ def read_lb(self, lb_path):
+ lb_data = Image.open(lb_path)
+
+ width, height = lb_data.size
+ width = int(width / 32) * 32
+ height = int(height / 32) * 32
+ lb_data = lb_data.resize((width, height), Image.Resampling.BILINEAR)
+ # print("lb_data.size:", lb_data.size)
+ lb = np.array(lb_data, dtype=np.float32)
+ if lb.ndim == 3:
+ lb = np.squeeze(lb[:, :, 0])
+ assert lb.ndim == 2
+ threshold = self.threshold
+ lb = lb[np.newaxis, :, :]
+
+ lb[lb == 0] = 0
+
+ # ---------- important ----------
+ if self.use_uncertainty:
+ lb[np.logical_and(lb > 0, lb < threshold)] = 2
+ else:
+ lb[np.logical_and(lb > 0, lb < threshold)] /= 255.
+
+ lb[lb >= threshold] = 1
+ return lb
+
+ def build_list(self):
+ data_root = os.path.abspath(self.data_root)
+ images_path = os.path.join(data_root, 'image', "raw")
+ labels_path = os.path.join(data_root, 'edge', "raw")
+
+ samples = []
+ for directory_name in os.listdir(images_path):
+ image_directories = os.path.join(images_path, directory_name)
+ for file_name_ext in os.listdir(image_directories):
+ file_name = os.path.basename(file_name_ext)
+ image_path = fit_img_postfix(os.path.join(images_path, directory_name, file_name))
+ lb_path = fit_img_postfix(os.path.join(labels_path, directory_name, file_name))
+ samples.append((image_path, lb_path))
+ return samples
+
+ def __getitem__(self, index):
+ img_path, edge_path = self.data_list[index]
+ # edge_path = self.edge_paths[index]
+ # img_path = self.img_paths[index]
+ img_name = os.path.basename(img_path)
+
+ img, raw_size = self.read_img(img_path)
+ edge = self.read_lb(edge_path)
+
+ # print("-------hhhhhhhhhhhhh--------:", img.shape, edge.shape)
+ # edge = Image.open(edge_path).convert('L')
+ # # default to score-sde preprocessing
+ # mask = Image.open(img_path).convert('RGB')
+ # edge, img = self.transform(edge, mask)
+ if self.normalize_to_neg_one_to_one: # transform to [-1, 1]
+ edge = normalize_to_neg_one_to_one(edge)
+ img = normalize_to_neg_one_to_one(img)
+ return {'image': edge, 'cond': img, 'raw_size': raw_size, 'img_name': img_name}
+
+class EdgeDataset(data.Dataset):
+ def __init__(
+ self,
+ data_root,
+ # mask_folder,
+ image_size,
+ exts = ['png', 'jpg'],
+ augment_horizontal_flip = True,
+ convert_image_to = None,
+ normalize_to_neg_one_to_one=True,
+ split='train',
+ # inter_type='bicubic',
+ # down=4,
+ threshold=0.3, use_uncertainty=False, cfg={}
+ ):
+ super().__init__()
+ # self.img_folder = Path(img_folder)
+ # self.edge_folder = Path(os.path.join(data_root, f'gt_imgs'))
+ # self.img_folder = Path(os.path.join(data_root, f'imgs'))
+ # self.edge_folder = Path(os.path.join(data_root, "edge", "aug"))
+ # self.img_folder = Path(os.path.join(data_root, "image", "aug"))
+ self.data_root = data_root
+ self.image_size = image_size
+
+ # self.edge_paths = [p for ext in exts for p in self.edge_folder.rglob(f'*.{ext}')]
+ # self.img_paths = [(self.img_folder / item.parent.name / f'{item.stem}.jpg') for item in self.edge_paths]
+ # self.img_paths = [(self.img_folder / f'{item.stem}.jpg') for item in self.edge_paths]
+
+ self.threshold = threshold * 255
+ self.use_uncertainty = use_uncertainty
+ self.normalize_to_neg_one_to_one = normalize_to_neg_one_to_one
+
+ maybe_convert_fn = partial(convert_image_to_fn, convert_image_to) if exists(convert_image_to) else Identity()
+
+ self.data_list = self.build_list()
+
+ # self.transform = Compose([
+ # Resize(image_size),
+ # RandomHorizontalFlip() if augment_horizontal_flip else Identity(),
+ # ToTensor()
+ # ])
+ crop_type = cfg.get('crop_type') if 'crop_type' in cfg else 'rand_crop'
+ if crop_type == 'rand_crop':
+ self.transform = Compose([
+ RandomCrop(image_size),
+ RandomHorizontalFlip() if augment_horizontal_flip else Identity(),
+ ToTensor()
+ ])
+ elif crop_type == 'rand_resize_crop':
+ self.transform = Compose([
+ RandomResizeCrop(image_size),
+ RandomHorizontalFlip() if augment_horizontal_flip else Identity(),
+ ToTensor()
+ ])
+ print("crop_type:", crop_type)
+
+ def __len__(self):
+ return len(self.data_list)
+
+
+ def read_img(self, image_path):
+ with open(image_path, 'rb') as f:
+ img = Image.open(f)
+ img = img.convert('RGB')
+
+ raw_width, raw_height = img.size
+ # width = int(raw_width / 32) * 32
+ # height = int(raw_height / 32) * 32
+ # img = img.resize((width, height), Image.Resampling.BILINEAR)
+ # # print("img.size:", img.size)
+ # img = self.transform(img)
+
+ return img, (raw_width, raw_height)
+
+ def read_lb(self, lb_path):
+ lb_data = Image.open(lb_path).convert('L')
+ lb = np.array(lb_data).astype(np.float32)
+ # width, height = lb_data.size
+ # width = int(width / 32) * 32
+ # height = int(height / 32) * 32
+ # lb_data = lb_data.resize((width, height), Image.Resampling.BILINEAR)
+ # print("lb_data.size:", lb_data.size)
+ # lb = np.array(lb_data, dtype=np.float32)
+ # if lb.ndim == 3:
+ # lb = np.squeeze(lb[:, :, 0])
+ # assert lb.ndim == 2
+ threshold = self.threshold
+ # lb = lb[np.newaxis, :, :]
+ # lb[lb == 0] = 0
+
+ # ---------- important ----------
+ # if self.use_uncertainty:
+ # lb[np.logical_and(lb > 0, lb < threshold)] = 2
+ # else:
+ # lb[np.logical_and(lb > 0, lb < threshold)] /= 255.
+
+ lb[lb >= threshold] = 255
+ lb = Image.fromarray(lb.astype(np.uint8))
+ return lb
+
+ def build_list(self):
+ data_root = os.path.abspath(self.data_root)
+ images_path = os.path.join(data_root, 'image')
+ labels_path = os.path.join(data_root, 'edge')
+
+ samples = []
+ for directory_name in os.listdir(images_path):
+ image_directories = os.path.join(images_path, directory_name)
+ for file_name_ext in os.listdir(image_directories):
+ file_name = os.path.basename(file_name_ext)
+ image_path = fit_img_postfix(os.path.join(images_path, directory_name, file_name))
+ lb_path = fit_img_postfix(os.path.join(labels_path, directory_name, file_name))
+ samples.append((image_path, lb_path))
+ return samples
+
+ def __getitem__(self, index):
+ img_path, edge_path = self.data_list[index]
+ # edge_path = self.edge_paths[index]
+ # img_path = self.img_paths[index]
+ img_name = os.path.basename(img_path)
+
+ img, raw_size = self.read_img(img_path)
+ edge = self.read_lb(edge_path)
+ img, edge = self.transform(img, edge)
+
+ # print("-------hhhhhhhhhhhhh--------:", img.shape, edge.shape)
+ # edge = Image.open(edge_path).convert('L')
+ # # default to score-sde preprocessing
+ # mask = Image.open(img_path).convert('RGB')
+ # edge, img = self.transform(edge, mask)
+ if self.normalize_to_neg_one_to_one: # transform to [-1, 1]
+ edge = normalize_to_neg_one_to_one(edge)
+ img = normalize_to_neg_one_to_one(img)
+ return {'image': edge, 'cond': img, 'raw_size': raw_size, 'img_name': img_name}
+
+class EdgeDatasetTest(data.Dataset):
+ def __init__(
+ self,
+ data_root,
+ # mask_folder,
+ image_size,
+ exts = ['png', 'jpg'],
+ convert_image_to = None,
+ normalize_to_neg_one_to_one=True,
+ ):
+ super().__init__()
+
+ self.data_root = data_root
+ self.image_size = image_size
+ self.normalize_to_neg_one_to_one = normalize_to_neg_one_to_one
+
+ maybe_convert_fn = partial(convert_image_to_fn, convert_image_to) if exists(convert_image_to) else Identity()
+
+ self.data_list = self.build_list()
+
+ self.transform = Compose([
+ ToTensor()
+ ])
+
+ def __len__(self):
+ return len(self.data_list)
+
+
+ def read_img(self, image_path):
+ with open(image_path, 'rb') as f:
+ img = Image.open(f)
+ img = img.convert('RGB')
+
+ raw_width, raw_height = img.size
+
+
+ return img, (raw_width, raw_height)
+
+ def read_lb(self, lb_path):
+ lb_data = Image.open(lb_path).convert('L')
+ lb = np.array(lb_data).astype(np.float32)
+
+ threshold = self.threshold
+
+
+ lb[lb >= threshold] = 255
+ lb = Image.fromarray(lb.astype(np.uint8))
+ return lb
+
+ def build_list(self):
+ data_root = os.path.abspath(self.data_root)
+ # images_path = os.path.join(data_root)
+ images_path = data_root
+ samples = get_imgs_list(images_path)
+ return samples
+
+ def __getitem__(self, index):
+ img_path = self.data_list[index]
+ # edge_path = self.edge_paths[index]
+ # img_path = self.img_paths[index]
+ img_name = os.path.basename(img_path)
+
+ img, raw_size = self.read_img(img_path)
+
+ img = self.transform(img)
+ if self.normalize_to_neg_one_to_one: # transform to [-1, 1]
+ img = normalize_to_neg_one_to_one(img)
+ return {'cond': img, 'raw_size': raw_size, 'img_name': img_name}
+
+
+class Identity(nn.Identity):
+ r"""A placeholder identity operator that is argument-insensitive.
+
+ Args:
+ args: any argument (unused)
+ kwargs: any keyword argument (unused)
+
+ Shape:
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+ - Output: :math:`(*)`, same shape as the input.
+
+ Examples::
+
+ >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
+ >>> input = torch.randn(128, 20)
+ >>> output = m(input)
+ >>> print(output.size())
+ torch.Size([128, 20])
+
+ """
+ def __init__(self, *args, **kwargs):
+ super(Identity, self).__init__(*args, **kwargs)
+
+ def forward(self, input, target):
+ return input, target
+
+class Resize(T.Resize):
+ def __init__(self, size, interpolation2=None, **kwargs):
+ super().__init__(size, **kwargs)
+ if interpolation2 is None:
+ self.interpolation2 = self.interpolation
+ else:
+ self.interpolation2 = interpolation2
+
+ def forward(self, img, target=None):
+ if target is None:
+ img = F2.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
+ return img
+ else:
+ img = F2.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
+ target = F2.resize(target, self.size, self.interpolation2, self.max_size, self.antialias)
+ return img, target
+
+class RandomHorizontalFlip(T.RandomHorizontalFlip):
+ def __init__(self, p=0.5):
+ super().__init__(p)
+
+ def forward(self, img, target=None):
+ if target is None:
+ if torch.rand(1) < self.p:
+ img = F2.hflip(img)
+ return img
+ else:
+ if torch.rand(1) < self.p:
+ img = F2.hflip(img)
+ target = F2.hflip(target)
+ return img, target
+
+class CenterCrop(T.CenterCrop):
+ def __init__(self, size):
+ super().__init__(size)
+
+ def forward(self, img, target=None):
+ if target is None:
+ img = F2.center_crop(img, self.size)
+ return img
+ else:
+ img = F2.center_crop(img, self.size)
+ target = F2.center_crop(target, self.size)
+ return img, target
+
+class RandomCrop(T.RandomCrop):
+ def __init__(self, size, **kwargs):
+ super().__init__(size, **kwargs)
+
+ def single_forward(self, img, i, j, h, w):
+ if self.padding is not None:
+ img = F2.pad(img, self.padding, self.fill, self.padding_mode)
+ width, height = F2.get_image_size(img)
+ # pad the width if needed
+ if self.pad_if_needed and width < self.size[1]:
+ padding = [self.size[1] - width, 0]
+ img = F2.pad(img, padding, self.fill, self.padding_mode)
+ # pad the height if needed
+ if self.pad_if_needed and height < self.size[0]:
+ padding = [0, self.size[0] - height]
+ img = F2.pad(img, padding, self.fill, self.padding_mode)
+
+ return F2.crop(img, i, j, h, w)
+
+ def forward(self, img, target=None):
+ i, j, h, w = self.get_params(img, self.size)
+ if target is None:
+ img = self.single_forward(img, i, j, h, w)
+ return img
+ else:
+ img = self.single_forward(img, i, j, h, w)
+ target = self.single_forward(target, i, j, h, w)
+ return img, target
+
+class RandomResizeCrop(T.RandomResizedCrop):
+ def __init__(self, size, scale=(0.25, 1.0), **kwargs):
+ super().__init__(size, scale, **kwargs)
+
+ # def single_forward(self, img, i, j, h, w):
+ # if self.padding is not None:
+ # img = F2.pad(img, self.padding, self.fill, self.padding_mode)
+ # width, height = F2.get_image_size(img)
+ # # pad the width if needed
+ # if self.pad_if_needed and width < self.size[1]:
+ # padding = [self.size[1] - width, 0]
+ # img = F2.pad(img, padding, self.fill, self.padding_mode)
+ # # pad the height if needed
+ # if self.pad_if_needed and height < self.size[0]:
+ # padding = [0, self.size[0] - height]
+ # img = F2.pad(img, padding, self.fill, self.padding_mode)
+ #
+ # return F2.crop(img, i, j, h, w)
+
+ def single_forward(self, img, i, j, h, w, interpolation=InterpolationMode.BILINEAR):
+ """
+ Args:
+ img (PIL Image or Tensor): Image to be cropped and resized.
+
+ Returns:
+ PIL Image or Tensor: Randomly cropped and resized image.
+ """
+ # i, j, h, w = self.get_params(img, self.scale, self.ratio)
+ return F2.resized_crop(img, i, j, h, w, self.size, interpolation)
+
+ def forward(self, img, target=None):
+ i, j, h, w = self.get_params(img, self.scale, self.ratio)
+ if target is None:
+ img = self.single_forward(img, i, j, h, w)
+ return img
+ else:
+ img = self.single_forward(img, i, j, h, w)
+ target = self.single_forward(target, i, j, h, w, interpolation=InterpolationMode.NEAREST)
+ return img, target
+
+class ToTensor(T.ToTensor):
+ def __init__(self):
+ super().__init__()
+
+ def __call__(self, img, target=None):
+ if target is None:
+ img = F2.to_tensor(img)
+ return img
+ else:
+ img = F2.to_tensor(img)
+ target = F2.to_tensor(target)
+ return img, target
+
+class Lambda(T.Lambda):
+ """Apply a user-defined lambda as a transform. This transform does not support torchscript.
+
+ Args:
+ lambd (function): Lambda/function to be used for transform.
+ """
+
+ def __init__(self, lambd):
+ super().__init__(lambd)
+
+ def __call__(self, img, target=None):
+ if target is None:
+ return self.lambd(img)
+ else:
+ return self.lambd(img), self.lambd(target)
+
+class Compose(T.Compose):
+ def __init__(self, transforms):
+ super().__init__(transforms)
+
+ def __call__(self, img, target=None):
+ if target is None:
+ for t in self.transforms:
+ img = t(img)
+ return img
+ else:
+ for t in self.transforms:
+ img, target = t(img, target)
+ return img, target
+
+
+if __name__ == '__main__':
+ dataset = CIFAR10(
+ img_folder='/media/huang/2da18d46-7cba-4259-9abd-0df819bb104c/data/cifar-10-python',
+ augment_horizontal_flip=False
+ )
+ # dataset = CityscapesDataset(
+ # # img_folder='/media/huang/2da18d46-7cba-4259-9abd-0df819bb104c/data/CelebAHQ/celeba_hq_256',
+ # data_root='/media/huang/2da18d46-7cba-4259-9abd-0df819bb104c/data/Cityscapes/',
+ # # data_root='/media/huang/2da18d46-7cba-4259-9abd-0df819bb104c/data/ADEChallengeData2016/',
+ # image_size=[512, 1024],
+ # exts = ['png'],
+ # augment_horizontal_flip = False,
+ # convert_image_to = None,
+ # normalize_to_neg_one_to_one=True,
+ # )
+ # dataset = SRDataset(
+ # img_folder='/media/huang/ZX3 512G/data/DIV2K/DIV2K_train_HR',
+ # image_size=[512, 512],
+ # )
+ # dataset = InpaintDataset(
+ # img_folder='/media/huang/2da18d46-7cba-4259-9abd-0df819bb104c/data/CelebAHQ/celeba_hq_256',
+ # image_size=[256, 256],
+ # augment_horizontal_flip = True
+ # )
+ dataset = EdgeDataset(
+ data_root='/media/huang/2da18d46-7cba-4259-9abd-0df819bb104c/data/BSDS',
+ image_size=[320, 320],
+ )
+ for i in range(len(dataset)):
+ d = dataset[i]
+ mask = d['cond']
+ print(mask.max())
+ dl = data.DataLoader(dataset, batch_size=2, shuffle=False, pin_memory=True, num_workers=0)
+
+
+ dataset_builder = tfds.builder('cifar10')
+ split = 'train'
+ dataset_options = tf.data.Options()
+ dataset_options.experimental_optimization.map_parallelization = True
+ dataset_options.experimental_threading.private_threadpool_size = 48
+ dataset_options.experimental_threading.max_intra_op_parallelism = 1
+ read_config = tfds.ReadConfig(options=dataset_options)
+ dataset_builder.download_and_prepare()
+ ds = dataset_builder.as_dataset(
+ split=split, shuffle_files=True, read_config=read_config)
+ pause = 0
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/ddm_const_sde.py b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/ddm_const_sde.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b733a2b00c4a3b750ccb35b0cbeabdc34ec0d0f
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/ddm_const_sde.py
@@ -0,0 +1,992 @@
+import torch
+import torch.nn as nn
+from torch.cuda.amp import custom_bwd, custom_fwd
+import math
+import torch.nn.functional as F
+# import torchvision.transforms.functional as F2
+from .utils import default, identity, normalize_to_neg_one_to_one, unnormalize_to_zero_to_one
+from tqdm.auto import tqdm
+from einops import rearrange, reduce
+from functools import partial
+from collections import namedtuple
+from random import random, randint, sample, choice
+from .encoder_decoder import DiagonalGaussianDistribution
+import random
+from custom_controlnet_aux.diffusion_edge.taming.modules.losses.vqperceptual import *
+
+# gaussian diffusion trainer class
+ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
+
+def extract(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+def linear_beta_schedule(timesteps):
+ scale = 1000 / timesteps
+ beta_start = scale * 0.0001
+ beta_end = scale * 0.02
+ return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
+
+def cosine_beta_schedule(timesteps, s = 0.008):
+ """
+ cosine schedule
+ as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
+ """
+ steps = timesteps + 1
+ x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
+ alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+ return torch.clip(betas, 0, 0.999)
+
+class DDPM(nn.Module):
+ def __init__(
+ self,
+ model,
+ *,
+ image_size,
+ timesteps = 1000,
+ sampling_timesteps = None,
+ loss_type = 'l2',
+ objective = 'pred_noise',
+ beta_schedule = 'cosine',
+ p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
+ p2_loss_weight_k = 1,
+ original_elbo_weight=0.,
+ ddim_sampling_eta = 1.,
+ clip_x_start=True,
+ train_sample=-1,
+ input_keys=['image'],
+ start_dist='normal',
+ sample_type='ddim',
+ perceptual_weight=1.,
+ use_l1=False,
+ **kwargs
+ ):
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ignore_keys = kwargs.pop("ignore_keys", [])
+ only_model = kwargs.pop("only_model", False)
+ cfg = kwargs.pop("cfg", None)
+ super().__init__(**kwargs)
+ # assert not (type(self) == DDPM and model.channels != model.out_dim)
+ # assert not model.random_or_learned_sinusoidal_cond
+
+ self.model = model
+ self.channels = self.model.channels
+ self.self_condition = self.model.self_condition
+ self.input_keys = input_keys
+ self.cfg = cfg
+ self.eps = cfg.get('eps', 1e-4) if cfg is not None else 1e-4
+ self.weighting_loss = cfg.get("weighting_loss", False) if cfg is not None else False
+ if self.weighting_loss:
+ print('#### WEIGHTING LOSS ####')
+
+ self.clip_x_start = clip_x_start
+ self.image_size = image_size
+ self.train_sample = train_sample
+ self.objective = objective
+ self.start_dist = start_dist
+ assert start_dist in ['normal', 'uniform']
+
+ assert objective in {'pred_noise', 'pred_x0', 'pred_v', 'pred_delta', 'pred_KC'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'
+
+ if beta_schedule == 'linear':
+ betas = linear_beta_schedule(timesteps)
+ elif beta_schedule == 'cosine':
+ betas = cosine_beta_schedule(timesteps, s=1e-4)
+ else:
+ raise ValueError(f'unknown beta schedule {beta_schedule}')
+ # betas[0] = 2e-3 * betas[0]
+ alphas = 1. - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
+
+ timesteps, = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.time_range = list(range(self.num_timesteps + 1))
+ self.loss_type = loss_type
+ self.original_elbo_weight = original_elbo_weight
+
+ # sampling related parameters
+
+ self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training
+
+ # assert self.sampling_timesteps <= timesteps
+ self.is_ddim_sampling = self.sampling_timesteps < timesteps
+ self.ddim_sampling_eta = ddim_sampling_eta
+
+ # helper function to register buffer from float64 to float32
+
+ register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
+
+ register_buffer('betas', betas)
+ register_buffer('alphas_cumprod', alphas_cumprod)
+ register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+
+ register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
+ register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
+ register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
+ register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
+ register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+
+ posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
+
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+
+ register_buffer('posterior_variance', posterior_variance)
+
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+
+ register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
+ register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
+ register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
+
+ # calculate p2 reweighting
+
+ register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)
+ assert not torch.isnan(self.p2_loss_weight).all()
+ if self.objective == "pred_noise":
+ lvlb_weights = self.betas ** 2 / (
+ 2 * (self.posterior_variance+1e-5) * alphas * (1 - self.alphas_cumprod))
+ elif self.objective == "pred_x0":
+ lvlb_weights = 0.5 * torch.sqrt(alphas_cumprod) / (2. * 1 - alphas_cumprod)
+ elif self.objective == "pred_delta":
+ lvlb_weights = 0.5 * torch.sqrt(alphas_cumprod) / (2. * 1 - alphas_cumprod)
+ elif self.objective == "pred_KC":
+ lvlb_weights = 0.5 * torch.sqrt(alphas_cumprod) / (2. * 1 - alphas_cumprod)
+ elif self.objective == "pred_v":
+ lvlb_weights = 0.5 * torch.sqrt(alphas_cumprod) / (2. * 1 - alphas_cumprod)
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
+ assert not torch.isnan(self.lvlb_weights).all()
+ self.use_l1 = use_l1
+
+ self.perceptual_weight = perceptual_weight
+ if self.perceptual_weight > 0:
+ self.perceptual_loss = LPIPS().eval()
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys, only_model)
+
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False, use_ema=False):
+ sd = torch.load(path, map_location="cpu")
+ if 'ema' in list(sd.keys()) and use_ema:
+ sd = sd['ema']
+ new_sd = {}
+ for k in sd.keys():
+ if k.startswith("ema_model."):
+ new_k = k[10:] # remove ema_model.
+ new_sd[new_k] = sd[k]
+ sd = new_sd
+ else:
+ if "model" in list(sd.keys()):
+ sd = sd["model"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+ sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+
+ @torch.no_grad()
+ def p_sample(self, x, mask, t: int, x_self_cond = None, clip_denoised = True):
+ b, *_, device = *x.shape, x.device
+ batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long)
+ model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, mask=mask, t = batched_times, x_self_cond = x_self_cond, clip_denoised = clip_denoised)
+ noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
+ pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
+ return pred_img, x_start
+
+ @torch.no_grad()
+ def p_sample_loop(self, shape, mask, up_scale=1, unnormalize=True):
+ batch, device = shape[0], self.betas.device
+
+ img = torch.randn(shape, device=device)
+ img = F.interpolate(img, scale_factor=up_scale, mode='bilinear', align_corners=True)
+
+ x_start = None
+
+ for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
+ self_cond = x_start if self.self_condition else None
+ img, x_start = self.p_sample(img, mask, t, self_cond)
+ if unnormalize:
+ img = unnormalize_to_zero_to_one(img)
+ return img
+
+ @torch.no_grad()
+ def ddim_sample(self, shape, mask, up_scale=1, unnormalize=True):
+ batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
+
+ times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
+ times = list(reversed(times.int().tolist()))
+ time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
+
+ img = torch.randn(shape, device = device)
+ img = F.interpolate(img, scale_factor=up_scale, mode='bilinear', align_corners=True)
+
+ x_start = None
+
+ for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', total=len(time_pairs)):
+ time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
+ self_cond = x_start if self.self_condition else None
+ pred_noise, x_start, *_ = self.model_predictions(img, time_cond, mask, self_cond)
+
+ if time_next < 0:
+ img = x_start
+ continue
+
+ alpha = self.alphas_cumprod[time]
+ alpha_next = self.alphas_cumprod[time_next]
+
+ sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
+ c = (1 - alpha_next - sigma ** 2).sqrt()
+
+ noise = torch.randn_like(img)
+
+ img = x_start * alpha_next.sqrt() + \
+ c * pred_noise + \
+ sigma * noise
+ if unnormalize:
+ img = unnormalize_to_zero_to_one(img)
+ return img
+
+
+ @torch.no_grad()
+ def interpolate(self, x1, x2, mask, t = None, lam = 0.5):
+ b, *_, device = *x1.shape, x1.device
+ t = default(t, self.num_timesteps - 1)
+
+ assert x1.shape == x2.shape
+
+ t_batched = torch.stack([torch.tensor(t, device = device)] * b)
+ xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2))
+
+ img = (1 - lam) * xt1 + lam * xt2
+ for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t):
+ img = self.p_sample(img, mask, torch.full((b,), i, device=device, dtype=torch.long))
+ return img
+
+ def get_input(self, batch, return_first_stage_outputs=False, return_original_cond=False):
+ assert 'image' in self.input_keys;
+ if len(self.input_keys) > len(batch.keys()):
+ x, *_ = batch.values()
+ else:
+ x = batch.values()
+ return x
+
+ def training_step(self, batch):
+ z, *_ = self.get_input(batch)
+ cond = batch['cond'] if 'cond' in batch else None
+ loss, loss_dict = self(z, cond)
+ return loss, loss_dict
+
+ def forward(self, x, *args, **kwargs):
+ # continuous time, t in [0, 1]
+ # t = []
+ # for _ in range(x.shape[0]):
+ # if self.train_sample <= 0:
+ # t.append(torch.tensor(sample(self.time_range, 2), device=x.device).long())
+ # else:
+ # sl = choice(self.time_range)
+ # sl_range = list(range(sl - self.train_sample, sl + self.train_sample))
+ # sl_range = list(set(sl_range) & set(self.time_range))
+ # sl_range.pop(sl_range.index(sl))
+ # sl2 = choice(sl_range)
+ # t.append(torch.tensor([sl, sl2], device=x.device).long())
+ # t = torch.stack(t, dim=0)
+ # t = torch.randint(0, self.num_timesteps+1, (x.shape[0],), device=x.device).long()
+ eps = self.eps # smallest time step
+ # t = torch.rand((x.shape[0],), device=x.device) * (self.num_timesteps / eps)
+ # t = t.round() * eps
+ # t[t < eps] = eps
+ t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
+ return self.p_losses(x, t, *args, **kwargs)
+
+ def q_sample2(self, x_start, t, noise=None):
+ b, c, h, w = x_start.shape
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ _, nt = t.shape
+ param_x = self.sqrt_alphas_cumprod.repeat(b, 1).gather(-1, t) # (b, nt)
+ x = x_start.expand(nt, b, c, h, w).transpose(1, 0) * param_x.reshape(b, nt, 1, 1, 1).repeat(1, 1, c, h, w)
+ param_noise = self.sqrt_one_minus_alphas_cumprod.repeat(b, 1).gather(-1, t)
+ n = noise.expand(nt, b, c, h, w).transpose(1, 0) * param_noise.reshape(b, nt, 1, 1, 1).repeat(1, 1, c, h, w)
+ return x + n # (b, nt, c, h, w)
+
+ def q_sample3(self, x_start, t, C):
+ b, c, h, w = x_start.shape
+ _, nt = t.shape
+ # K_ = K.unsqueeze(1).repeat(1, nt, 1, 1, 1)
+ C_ = C.unsqueeze(1).repeat(1, nt, 1, 1, 1)
+ x_noisy = x_start.expand(nt, b, c, h, w).transpose(1, 0) + \
+ + C_ * t.reshape(b, nt, 1, 1, 1).repeat(1, 1, c, h, w) / self.num_timesteps
+ return x_noisy # (b, nt, c, h, w)
+
+ # def q_sample(self, x_start, t, C):
+ # x_noisy = x_start + C * t.reshape(C.shape[0], *((1,) * (len(C.shape) - 1))) / self.num_timesteps
+ # return x_noisy
+ def q_sample(self, x_start, noise, t, C):
+ time = t.reshape(C.shape[0], *((1,) * (len(C.shape) - 1)))
+ x_noisy = x_start + C * time + torch.sqrt(time) * noise
+ return x_noisy
+
+ def q_sample2(self, x_start, noise, t, C):
+ time = t.reshape(C.shape[0], *((1,) * (len(C.shape) - 1)))
+ x_noisy = x_start + C / 2 * time ** 2 + torch.sqrt(time) * noise
+ return x_noisy
+
+ def pred_x0_from_xt(self, xt, noise, C, t):
+ time = t.reshape(C.shape[0], *((1,) * (len(C.shape) - 1)))
+ x0 = xt - C * time - torch.sqrt(time) * noise
+ return x0
+
+ def pred_x0_from_xt2(self, xt, noise, C, t):
+ time = t.reshape(C.shape[0], *((1,) * (len(C.shape) - 1)))
+ x0 = xt - C / 2 * time ** 2 - torch.sqrt(time) * noise
+ return x0
+
+ def pred_xtms_from_xt(self, xt, noise, C, t, s):
+ # noise = noise / noise.std(dim=[1, 2, 3]).reshape(C.shape[0], *((1,) * (len(C.shape) - 1)))
+ time = t.reshape(C.shape[0], *((1,) * (len(C.shape) - 1)))
+ s = s.reshape(C.shape[0], *((1,) * (len(C.shape) - 1)))
+ mean = xt + C * (time-s) - C * time - s / torch.sqrt(time) * noise
+ epsilon = torch.randn_like(mean, device=xt.device)
+ sigma = torch.sqrt(s * (time-s) / time)
+ xtms = mean + sigma * epsilon
+ return xtms
+
+ def pred_xtms_from_xt2(self, xt, noise, C, t, s):
+ time = t.reshape(C.shape[0], *((1,) * (len(C.shape) - 1)))
+ s = s.reshape(C.shape[0], *((1,) * (len(C.shape) - 1)))
+ mean = xt + C / 2 * (time-s) ** 2 - C / 2 * time ** 2 - s / torch.sqrt(time) * noise
+ epsilon = torch.randn_like(mean, device=xt.device)
+ sigma = torch.sqrt(s * (time-s) / time)
+ xtms = mean + sigma * epsilon
+ return xtms
+
+ def WCE_loss(self, prediction, labelf, beta=1.1):
+ label = labelf.long()
+ mask = labelf.clone()
+
+ num_positive = torch.sum(label == 1).float()
+ num_negative = torch.sum(label == 0).float()
+
+ mask[label == 1] = 1.0 * num_negative / (num_positive + num_negative)
+ mask[label == 0] = beta * num_positive / (num_positive + num_negative)
+ mask[label == 2] = 0
+ cost = F.binary_cross_entropy(prediction, labelf, weight=mask, reduction='sum')
+
+ return cost
+
+ def Dice_Loss(self, pred, label):
+ # pred = torch.sigmoid(pred)
+ smooth = 1
+ pred_flat = pred.view(-1)
+ label_flat = label.view(-1)
+
+ intersecion = pred_flat * label_flat
+ unionsection = pred_flat.pow(2).sum() + label_flat.pow(2).sum() + smooth
+ loss = unionsection / (2 * intersecion.sum() + smooth)
+ loss = loss.sum()
+ return loss
+
+ def p_losses(self, x_start, t, *args, **kwargs):
+ if self.start_dist == 'normal':
+ noise = torch.randn_like(x_start)
+ elif self.start_dist == 'uniform':
+ noise = 2 * torch.rand_like(x_start) - 1.
+ else:
+ raise NotImplementedError(f'{self.start_dist} is not supported !')
+ # K = -1. * torch.ones_like(x_start)
+ # C = noise - x_start # t = 1000 / 1000
+ C = -1 * x_start # U(t) = Ct, U(1) = -x0
+ # C = -2 * x_start # U(t) = 1/2 * C * t**2, U(1) = 1/2 * C = -x0
+ x_noisy = self.q_sample(x_start=x_start, noise=noise, t=t, C=C) # (b, 2, c, h, w)
+ C_pred, noise_pred = self.model(x_noisy, t, **kwargs)
+ # C_pred = C_pred / torch.sqrt(t)
+ # noise_pred = noise_pred / torch.sqrt(1 - t)
+ x_rec = self.pred_x0_from_xt(x_noisy, noise_pred, C_pred, t) # x_rec:(B, 1, H, W)
+ loss_dict = {}
+ prefix = 'train'
+
+ # elif self.objective == 'pred_KC':
+ # target1 = C
+ # target2 = noise
+ # target3 = x_start
+
+ target1 = C
+ target2 = noise
+ target3 = x_start
+
+ loss_simple = 0.
+ loss_vlb = 0.
+ # use l1 + l2
+ if self.weighting_loss:
+ simple_weight1 = 2*torch.exp(1-t)
+ simple_weight2 = torch.exp(torch.sqrt(t))
+ if self.cfg.model_name == 'ncsnpp9':
+ simple_weight1 = (t + 1) / t.sqrt()
+ simple_weight2 = (2 - t).sqrt() / (1 - t + self.eps).sqrt()
+ else:
+ simple_weight1 = 1
+ simple_weight2 = 1
+
+ loss_simple += simple_weight1 * self.get_loss(C_pred, target1, mean=False).mean([1, 2, 3]) + \
+ simple_weight2 * self.get_loss(noise_pred, target2, mean=False).mean([1, 2, 3])
+ if self.use_l1:
+ loss_simple += simple_weight1 * (C_pred - target1).abs().mean([1, 2, 3]) + \
+ simple_weight2 * (noise_pred - target2).abs().mean([1, 2, 3])
+ loss_simple = loss_simple / 2
+ # rec_weight = (1 - t.reshape(C.shape[0], 1)) ** 2
+ rec_weight = 1 - t.reshape(C.shape[0], 1) # (B, 1)
+ loss_simple = loss_simple.mean()
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple})
+
+ # loss_vlb += torch.abs(x_rec - target3).mean([1, 2, 3]) * rec_weight: (B, 1)
+ loss_vlb += self.Dice_Loss(x_rec, target3)
+ loss_vlb = loss_vlb.mean()
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
+
+ loss = loss_simple + loss_vlb
+ loss_dict.update({f'{prefix}/loss': loss})
+
+ return loss, loss_dict
+
+ def get_loss(self, pred, target, mean=True):
+ if self.loss_type == 'l1':
+ loss = (target - pred).abs()
+ if mean:
+ loss = loss.mean()
+ elif self.loss_type == 'l2':
+ if mean:
+ loss = torch.nn.functional.mse_loss(target, pred)
+ else:
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
+ else:
+ raise NotImplementedError("unknown loss type '{loss_type}'")
+
+ return loss
+
+ @torch.no_grad()
+ def sample(self, batch_size=16, up_scale=1, cond=None, denoise=True):
+ image_size, channels = self.image_size, self.channels
+ if cond is not None:
+ batch_size = cond.shape[0]
+ return self.sample_fn((batch_size, channels, image_size[0], image_size[1]),
+ up_scale=up_scale, unnormalize=True, cond=cond, denoise=denoise)
+
+ @torch.no_grad()
+ def sample_fn(self, shape, up_scale=1, unnormalize=True, cond=None, denoise=False):
+ batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], \
+ self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
+
+ # times = torch.linspace(-1, total_timesteps, steps=self.sampling_timesteps + 1).int()
+ # times = list(reversed(times.int().tolist()))
+ # time_pairs = list(zip(times[:-1], times[1:]))
+ # time_steps = torch.tensor([0.25, 0.15, 0.1, 0.1, 0.1, 0.09, 0.075, 0.06, 0.045, 0.03])
+ step = 1. / self.sampling_timesteps
+ # time_steps = torch.tensor([0.1]).repeat(10)
+ time_steps = torch.tensor([step]).repeat(self.sampling_timesteps)
+ if denoise:
+ eps = self.eps
+ time_steps = torch.cat((time_steps[:-1], torch.tensor([step - eps]), torch.tensor([eps])), dim=0)
+
+ if self.start_dist == 'normal':
+ img = torch.randn(shape, device=device)
+ elif self.start_dist == 'uniform':
+ img = 2 * torch.rand(shape, device=device) - 1.
+ else:
+ raise NotImplementedError(f'{self.start_dist} is not supported !')
+ img = F.interpolate(img, scale_factor=up_scale, mode='bilinear', align_corners=True)
+ # K = -1 * torch.ones_like(img)
+ cur_time = torch.ones((batch,), device=device)
+ for i, time_step in enumerate(time_steps):
+ s = torch.full((batch,), time_step, device=device)
+ if i == time_steps.shape[0] - 1:
+ s = cur_time
+ if cond is not None:
+ pred = self.model(img, cur_time, cond)
+ else:
+ pred = self.model(img, cur_time)
+ # C, noise = pred.chunk(2, dim=1)
+ C, noise = pred[:2]
+ # correct C
+ x0 = self.pred_x0_from_xt(img, noise, C, cur_time)
+ if self.clip_x_start:
+ x0.clamp_(-1., 1.)
+ # C.clamp_(-2., 2.)
+ C = -1 * x0
+ img = self.pred_xtms_from_xt(img, noise, C, cur_time, s)
+ # img = self.pred_xtms_from_xt2(img, noise, C, cur_time, s)
+ cur_time = cur_time - s
+ img.clamp_(-1., 1.)
+ if unnormalize:
+ img = unnormalize_to_zero_to_one(img)
+ return img
+
+
+
+class LatentDiffusion(DDPM):
+ def __init__(self,
+ auto_encoder,
+ scale_factor=1.0,
+ scale_by_std=True,
+ scale_by_softsign=False,
+ input_keys=['image'],
+ sample_type='ddim',
+ num_timesteps_cond=1,
+ train_sample=-1,
+ default_scale=False,
+ *args,
+ **kwargs
+ ):
+ self.scale_by_std = scale_by_std
+ self.scale_by_softsign = scale_by_softsign
+ self.default_scale = default_scale
+ self.num_timesteps_cond = num_timesteps_cond
+ self.train_sample = train_sample
+ self.perceptual_weight = 0
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ignore_keys = kwargs.pop("ignore_keys", [])
+ only_model = kwargs.pop("only_model", False)
+ super().__init__(*args, **kwargs)
+ assert self.num_timesteps_cond <= self.num_timesteps
+ if not scale_by_std:
+ self.scale_factor = scale_factor
+ else:
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
+ if self.scale_by_softsign:
+ self.scale_by_std = False
+ print('### USING SOFTSIGN RESCALING')
+ assert (self.scale_by_std and self.scale_by_softsign) is False;
+
+ self.init_first_stage(auto_encoder)
+ # self.instantiate_cond_stage(cond_stage_config)
+ self.input_keys = input_keys
+ self.clip_denoised = False
+ assert sample_type in ['p_loop', 'ddim', 'dpm', 'transformer'] ### 'dpm' is not availible now, suggestion 'ddim'
+ self.sample_type = sample_type
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys, only_model)
+
+ def init_first_stage(self, first_stage_model):
+ self.first_stage_model = first_stage_model.eval()
+ # self.first_stage_model.train = disabled_train
+ for param in self.first_stage_model.parameters():
+ param.requires_grad = False
+
+ '''
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+ sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+ '''
+
+ def get_first_stage_encoding(self, encoder_posterior):
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+ z = encoder_posterior.sample()
+ elif isinstance(encoder_posterior, torch.Tensor):
+ z = encoder_posterior
+ else:
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
+ # return self.scale_factor * z.detach() + self.scale_bias
+ return z.detach()
+
+ @torch.no_grad()
+ def on_train_batch_start(self, batch):
+ # only for the first batch
+ if self.scale_by_std and (not self.scale_by_softsign):
+ if not self.default_scale:
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
+ # set rescale weight to 1./std of encodings
+ print("### USING STD-RESCALING ###")
+ x, *_ = batch.values()
+ encoder_posterior = self.first_stage_model.encode(x)
+ z = self.get_first_stage_encoding(encoder_posterior)
+ del self.scale_factor
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
+ print(f"setting self.scale_factor to {self.scale_factor}")
+ # print("### USING STD-RESCALING ###")
+ else:
+ print(f'### USING DEFAULT SCALE {self.scale_factor}')
+ else:
+ print(f'### USING SOFTSIGN SCALE !')
+
+ @torch.no_grad()
+ def get_input(self, batch, return_first_stage_outputs=False, return_original_cond=False):
+ assert 'image' in self.input_keys;
+ # if len(self.input_keys) > len(batch.keys()):
+ # x, cond, *_ = batch.values()
+ # else:
+ # x, cond = batch.values()
+ x = batch['image']
+ cond = batch['cond'] if 'cond' in batch else None
+ z = self.first_stage_model.encode(x)
+ # print('zzzz', z.shape)
+ z = self.get_first_stage_encoding(z)
+ out = [z, cond, x]
+ if return_first_stage_outputs:
+ xrec = self.first_stage_model.decode(z)
+ out.extend([x, xrec])
+ if return_original_cond:
+ out.append(cond)
+ return out
+
+ def training_step(self, batch):
+ z, c, *_ = self.get_input(batch)
+ # print(_[0].shape)
+ if self.scale_by_softsign:
+ z = F.softsign(z)
+ elif self.scale_by_std:
+ z = self.scale_factor * z
+ # print('grad', self.scale_bias.grad)
+ loss, loss_dict = self(z, c, edge=_[0])
+ return loss, loss_dict
+
+ def q_sample3(self, x_start, t, K, C):
+ b, c, h, w = x_start.shape
+ _, nt = t.shape
+ K_ = K.unsqueeze(1).repeat(1, nt, 1, 1, 1)
+ C_ = C.unsqueeze(1).repeat(1, nt, 1, 1, 1)
+ x_noisy = x_start.expand(nt, b, c, h, w).transpose(1, 0) + K_ / 2 * (t.reshape(b, nt, 1, 1, 1).repeat(1, 1, c, h, w) / self.num_timesteps) ** 2 \
+ + C_ * t.reshape(b, nt, 1, 1, 1).repeat(1, 1, c, h, w) / self.num_timesteps
+ return x_noisy # (b, nt, c, h, w)
+
+ def pred_xtms_from_xt(self, xt, noise, C, t, s):
+ # noise = noise / noise.std(dim=[1, 2, 3]).reshape(C.shape[0], *((1,) * (len(C.shape) - 1)))
+ time = t.reshape(C.shape[0], *((1,) * (len(C.shape) - 1)))
+ s = s.reshape(C.shape[0], *((1,) * (len(C.shape) - 1)))
+ mean = xt - C * s - s / torch.sqrt(time) * noise
+ epsilon = torch.randn_like(mean, device=xt.device)
+ sigma = torch.sqrt(s * (time-s) / time)
+ xtms = mean + sigma * epsilon
+ return xtms
+
+ def WCE_loss(self, prediction, labelf, beta=1.1):
+ label = labelf.long()
+ mask = labelf.clone()
+
+ num_positive = torch.sum(label == 1).float()
+ num_negative = torch.sum(label == 0).float()
+
+ mask[label == 1] = 1.0 * num_negative / (num_positive + num_negative)
+ mask[label == 0] = beta * num_positive / (num_positive + num_negative)
+ mask[label == 2] = 0
+ cost = F.binary_cross_entropy(prediction, labelf, weight=mask, reduction='sum')
+
+ return cost
+
+ def Dice_Loss(self, pred, label):
+ # pred = torch.sigmoid(pred)
+ B = pred.shape[0]
+ smooth = 1
+ pred_flat = pred.view(B, -1)
+ label_flat = label.view(B, -1)
+
+ intersecion = pred_flat * label_flat
+ unionsection = pred_flat.pow(2).sum(dim=-1) + label_flat.pow(2).sum(dim=-1) + smooth
+ loss = unionsection / (2 * intersecion.sum(dim=-1) + smooth)
+ loss = loss.reshape(B, 1)
+ return loss
+
+ def p_losses(self, x_start, t, *args, **kwargs):
+ if self.start_dist == 'normal':
+ noise = torch.randn_like(x_start)
+ elif self.start_dist == 'uniform':
+ noise = 2 * torch.rand_like(x_start) - 1.
+ else:
+ raise NotImplementedError(f'{self.start_dist} is not supported !')
+ # K = -1. * torch.ones_like(x_start)
+ # C = noise - x_start # t = 1000 / 1000
+ C = -1 * x_start # U(t) = Ct, U(1) = -x0
+ # C = -2 * x_start # U(t) = 1/2 * C * t**2, U(1) = 1/2 * C = -x0
+ x_noisy = self.q_sample(x_start=x_start, noise=noise, t=t, C=C) # (b, 2, c, h, w)
+ if self.cfg.model_name == 'cond_unet8':
+ C_pred, noise_pred, (e1, e2) = self.model(x_noisy, t, *args, **kwargs)
+ if self.cfg.model_name == 'cond_unet13':
+ C_pred, noise_pred, aux_C = self.model(x_noisy, t, *args, **kwargs)
+ else:
+ C_pred, noise_pred = self.model(x_noisy, t, *args, **kwargs)
+ # C_pred = C_pred / torch.sqrt(t)
+ # noise_pred = noise_pred / torch.sqrt(1 - t)
+ x_rec = self.pred_x0_from_xt(x_noisy, noise_pred, C_pred, t) # x_rec:(B, C, H, W)
+ loss_dict = {}
+ prefix = 'train'
+
+ # elif self.objective == 'pred_KC':
+ # target1 = C
+ # target2 = noise
+ # target3 = x_start
+
+ target1 = C
+ target2 = noise
+ target3 = x_start
+
+ loss_simple = 0.
+ loss_vlb = 0.
+
+ simple_weight1 = (t + 1) / t.sqrt()
+ simple_weight2 = (2 - t).sqrt() / (1 - t + self.eps).sqrt()
+
+ # if self.weighting_loss:
+ # simple_weight1 = 2 * torch.exp(1 - t)
+ # simple_weight2 = torch.exp(torch.sqrt(t))
+ # if self.cfg.model_name == 'ncsnpp9':
+ # simple_weight1 = (t + 1) / t.sqrt()
+ # simple_weight2 = (2 - t).sqrt() / (1 - t + self.eps).sqrt()
+ # else:
+ # simple_weight1 = 1
+ # simple_weight2 = 1
+
+ loss_simple += simple_weight1 * self.get_loss(C_pred, target1, mean=False).mean([1, 2, 3]) + \
+ simple_weight2 * self.get_loss(noise_pred, target2, mean=False).mean([1, 2, 3])
+
+ # loss_simple += self.Dice_Loss(C_pred, target1) * simple_weight1
+
+ if self.use_l1:
+ loss_simple += simple_weight1 * (C_pred - target1).abs().mean([1, 2, 3]) + \
+ simple_weight2 * (noise_pred - target2).abs().mean([1, 2, 3])
+ loss_simple = loss_simple / 2
+
+ if self.cfg.model_name == 'cond_unet8':
+ loss_simple += 0.05*(self.Dice_Loss(e1, (kwargs['edge'] + 1)/2) + self.Dice_Loss(e2, (kwargs['edge'] + 1)/2))
+ elif self.cfg.model_name == 'cond_unet13':
+ loss_simple += 0.5 * (simple_weight1 * self.get_loss(aux_C, target1, mean=False).mean([1, 2, 3]) + \
+ simple_weight1 * (aux_C - target1).abs().mean([1, 2, 3]))
+
+ rec_weight = (1 - t.reshape(C.shape[0], 1)) ** 2
+ # rec_weight = 1 - t.reshape(C.shape[0], 1) # (B, 1)
+ loss_simple = loss_simple.mean()
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple})
+
+ loss_vlb += torch.abs(x_rec - target3).mean([1, 2, 3]) * rec_weight # : (B, 1)
+ # loss_vlb += self.Dice_Loss(x_rec, target3) * rec_weight
+
+ # loss_vlb = loss_vlb
+ loss_vlb = loss_vlb.mean()
+
+ if self.cfg.get('use_disloss', False):
+ with torch.no_grad():
+ edge_rec = self.first_stage_model.decode(x_rec / self.scale_factor)
+ edge_rec = unnormalize_to_zero_to_one(edge_rec)
+ edge_rec = torch.clamp(edge_rec, min=0., max=1.) # B, 1, 320, 320
+ loss_tmp = self.cross_entropy_loss_RCF(edge_rec, (kwargs['edge'] + 1)/2) * rec_weight # B, 1
+ loss_ce = SpecifyGradient.apply(x_rec, loss_tmp.mean())
+ # print(loss_ce.shape)
+ # print(loss_vlb.shape)
+ loss_vlb += loss_ce.mean()
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
+
+ loss = loss_simple + loss_vlb
+ loss_dict.update({f'{prefix}/loss': loss})
+
+ return loss, loss_dict
+
+ def get_loss(self, pred, target, mean=True):
+ if self.loss_type == 'l1':
+ loss = (target - pred).abs()
+ if mean:
+ loss = loss.mean()
+ elif self.loss_type == 'l2':
+ if mean:
+ loss = torch.nn.functional.mse_loss(target, pred)
+ else:
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
+ else:
+ raise NotImplementedError("unknown loss type '{loss_type}'")
+
+ return loss
+
+ def cross_entropy_loss_RCF(self, prediction, labelf, beta=1.1):
+ # label = labelf.long()
+ label = labelf
+ mask = labelf.clone()
+
+ num_positive = torch.sum(label == 1).float()
+ num_negative = torch.sum(label == 0).float()
+
+ mask_temp = (label > 0) & (label <= 0.3)
+ mask[mask_temp] = 0.
+
+ mask[label == 1] = 1.0 * num_negative / (num_positive + num_negative)
+ mask[label == 0] = beta * num_positive / (num_positive + num_negative)
+
+ # mask[label == 2] = 0
+ cost = F.binary_cross_entropy(prediction, labelf, weight=mask, reduction='none')
+ return cost.mean([1, 2, 3])
+
+ @torch.no_grad()
+ def sample(self, batch_size=16, up_scale=1, cond=None, mask=None, denoise=True):
+ # image_size, channels = self.image_size, self.channels
+ channels = self.channels
+ image_size = cond.shape[-2:]
+ if cond is not None:
+ batch_size = cond.shape[0]
+ down_ratio = self.first_stage_model.down_ratio
+ if self.cfg.model_name == 'cond_unet8' or self.cfg.model_name == 'cond_unet13':
+ z, aux_out = self.sample_fn((batch_size, channels, image_size[0] // down_ratio, image_size[1] // down_ratio),
+ up_scale=up_scale, unnormalize=False, cond=cond, denoise=denoise)
+ else:
+ z = self.sample_fn((batch_size, channels, image_size[0]//down_ratio, image_size[1]//down_ratio),
+ up_scale=up_scale, unnormalize=False, cond=cond, denoise=denoise)
+ aux_out = None
+
+ if self.scale_by_std:
+ z = 1. / self.scale_factor * z.detach()
+ if self.cfg.model_name == 'cond_unet13':
+ aux_out = 1. / self.scale_factor * aux_out.detach()
+ elif self.scale_by_softsign:
+ z = z / (1 - z.abs())
+ z = z.detach()
+ #print(z.shape)
+ x_rec = self.first_stage_model.decode(z)
+ x_rec = unnormalize_to_zero_to_one(x_rec)
+ x_rec = torch.clamp(x_rec, min=0., max=1.)
+ if self.cfg.model_name == 'cond_unet13':
+ aux_out = self.first_stage_model.decode(aux_out)
+ aux_out = unnormalize_to_zero_to_one(aux_out)
+ aux_out = torch.clamp(aux_out, min=0., max=1.)
+ if mask is not None:
+ x_rec = mask * unnormalize_to_zero_to_one(cond) + (1 - mask) * x_rec
+ if aux_out is not None:
+ return x_rec, aux_out
+ return x_rec
+
+ @torch.no_grad()
+ def sample_fn(self, shape, up_scale=1, unnormalize=True, cond=None, denoise=False):
+ batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], \
+ self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
+
+ # times = torch.linspace(-1, total_timesteps, steps=self.sampling_timesteps + 1).int()
+ # times = list(reversed(times.int().tolist()))
+ # time_pairs = list(zip(times[:-1], times[1:]))
+ # time_steps = torch.tensor([0.25, 0.15, 0.1, 0.1, 0.1, 0.09, 0.075, 0.06, 0.045, 0.03])
+ step = 1. / self.sampling_timesteps
+ # time_steps = torch.tensor([0.1]).repeat(10)
+ time_steps = torch.tensor([step]).repeat(self.sampling_timesteps)
+ if denoise:
+ eps = self.eps
+ time_steps = torch.cat((time_steps[:-1], torch.tensor([step - eps]), torch.tensor([eps])), dim=0)
+
+ if self.start_dist == 'normal':
+ img = torch.randn(shape, device=device)
+ elif self.start_dist == 'uniform':
+ img = 2 * torch.rand(shape, device=device) - 1.
+ else:
+ raise NotImplementedError(f'{self.start_dist} is not supported !')
+ img = F.interpolate(img, scale_factor=up_scale, mode='bilinear', align_corners=True)
+ img_aux = F.interpolate(img.clone(), scale_factor=up_scale, mode='bilinear', align_corners=True)
+ # img_aux = img.clone()
+ # K = -1 * torch.ones_like(img)
+ cur_time = torch.ones((batch,), device=device)
+ for i, time_step in enumerate(time_steps):
+ s = torch.full((batch,), time_step, device=device)
+ if i == time_steps.shape[0] - 1:
+ s = cur_time
+ if cond is not None:
+ pred = self.model(img, cur_time, cond)
+ else:
+ pred = self.model(img, cur_time)
+ # C, noise = pred.chunk(2, dim=1)
+ C, noise = pred[:2]
+ if self.cfg.model_name == 'cond_unet8' or self.cfg.model_name == 'cond_unet13':
+ aux_out = pred[-1]
+ else:
+ aux_out = None
+ # if self.scale_by_softsign:
+ # # correct the C for softsign
+ # x0 = self.pred_x0_from_xt(img, noise, C, cur_time)
+ # x0 = torch.clamp(x0, min=-0.987654321, max=0.987654321)
+ # C = -x0
+ # correct C
+ x0 = self.pred_x0_from_xt(img, noise, C, cur_time)
+ C = -1 * x0
+ img = self.pred_xtms_from_xt(img, noise, C, cur_time, s)
+ # if self.cfg.model_name == 'cond_unet13' and i == len(time_steps) - 2:
+ # img_aux = img
+ # if self.cfg.model_name == 'cond_unet13' and i in [len(time_steps)-2, len(time_steps)-1]:
+ # x0_aux = self.pred_x0_from_xt(img_aux, noise, aux_out, cur_time)
+ # C_aux = -1 * x0_aux
+ # img_aux = self.pred_xtms_from_xt(img_aux, noise, C_aux, cur_time, s)
+ if self.cfg.model_name == 'cond_unet13':
+ for _ in range(1):
+ x0_aux = self.pred_x0_from_xt(img_aux, noise, aux_out, cur_time)
+ C_aux = -1 * x0_aux
+ img_aux = self.pred_xtms_from_xt(img_aux, noise, C_aux, cur_time, s)
+ cur_time = cur_time - s
+ if self.scale_by_softsign:
+ img.clamp_(-0.987654321, 0.987654321)
+ if unnormalize:
+ img = unnormalize_to_zero_to_one(img)
+ if self.cfg.model_name == 'cond_unet13':
+ aux_out = img_aux
+ if aux_out is not None:
+ return img, aux_out
+ return img
+
+class SpecifyGradient(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd
+ def forward(ctx, input_tensor, gt_grad):
+ ctx.save_for_backward(gt_grad)
+ # we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward.
+ return torch.ones(input_tensor.shape, device=input_tensor.device, dtype=input_tensor.dtype)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_scale):
+ (gt_grad,) = ctx.saved_tensors
+ gt_grad = gt_grad * grad_scale
+ return gt_grad, None
+
+if __name__ == "__main__":
+ ddconfig = {'double_z': True,
+ 'z_channels': 4,
+ 'resolution': (240, 960),
+ 'in_channels': 3,
+ 'out_ch': 3,
+ 'ch': 128,
+ 'ch_mult': [1, 2, 4, 4], # num_down = len(ch_mult)-1
+ 'num_res_blocks': 2,
+ 'attn_resolutions': [],
+ 'dropout': 0.0}
+ lossconfig = {'disc_start': 50001,
+ 'kl_weight': 0.000001,
+ 'disc_weight': 0.5}
+ from encoder_decoder import AutoencoderKL
+ auto_encoder = AutoencoderKL(ddconfig, lossconfig, embed_dim=4,
+ )
+ from mask_cond_unet import Unet
+ unet = Unet(dim=64, dim_mults=(1, 2, 4, 8), channels=4, cond_in_dim=1,)
+ ldm = LatentDiffusion(auto_encoder=auto_encoder, model=unet, image_size=ddconfig['resolution'])
+ image = torch.rand(1, 3, 128, 128)
+ mask = torch.rand(1, 1, 128, 128)
+ input = {'image': image, 'cond': mask}
+ time = torch.tensor([1])
+ with torch.no_grad():
+ y = ldm.training_step(input)
+ pass
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/efficientnet.py b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/efficientnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee49d7ab757febe8ef840020e0b4ae67f7311422
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/efficientnet.py
@@ -0,0 +1,1130 @@
+import copy
+import math
+import warnings
+from dataclasses import dataclass
+from functools import partial
+from typing import Any, Callable, Dict, Optional, List, Sequence, Tuple, Union
+
+import torch
+from torch import nn, Tensor
+from torchvision.ops import StochasticDepth
+
+from torchvision.ops.misc import Conv2dNormActivation, SqueezeExcitation
+from torchvision.transforms._presets import ImageClassification, InterpolationMode
+from torchvision.utils import _log_api_usage_once
+from torchvision.models._api import WeightsEnum, Weights
+from torchvision.models._meta import _IMAGENET_CATEGORIES
+from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible
+
+
+__all__ = [
+ "EfficientNet",
+ "EfficientNet_B0_Weights",
+ "EfficientNet_B1_Weights",
+ "EfficientNet_B2_Weights",
+ "EfficientNet_B3_Weights",
+ "EfficientNet_B4_Weights",
+ "EfficientNet_B5_Weights",
+ "EfficientNet_B6_Weights",
+ "EfficientNet_B7_Weights",
+ "EfficientNet_V2_S_Weights",
+ "EfficientNet_V2_M_Weights",
+ "EfficientNet_V2_L_Weights",
+ "efficientnet_b0",
+ "efficientnet_b1",
+ "efficientnet_b2",
+ "efficientnet_b3",
+ "efficientnet_b4",
+ "efficientnet_b5",
+ "efficientnet_b6",
+ "efficientnet_b7",
+ "efficientnet_v2_s",
+ "efficientnet_v2_m",
+ "efficientnet_v2_l",
+]
+
+
+@dataclass
+class _MBConvConfig:
+ expand_ratio: float
+ kernel: int
+ stride: int
+ input_channels: int
+ out_channels: int
+ num_layers: int
+ block: Callable[..., nn.Module]
+
+ @staticmethod
+ def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int:
+ return _make_divisible(channels * width_mult, 8, min_value)
+
+
+class MBConvConfig(_MBConvConfig):
+ # Stores information listed at Table 1 of the EfficientNet paper & Table 4 of the EfficientNetV2 paper
+ def __init__(
+ self,
+ expand_ratio: float,
+ kernel: int,
+ stride: int,
+ input_channels: int,
+ out_channels: int,
+ num_layers: int,
+ width_mult: float = 1.0,
+ depth_mult: float = 1.0,
+ block: Optional[Callable[..., nn.Module]] = None,
+ ) -> None:
+ input_channels = self.adjust_channels(input_channels, width_mult)
+ out_channels = self.adjust_channels(out_channels, width_mult)
+ num_layers = self.adjust_depth(num_layers, depth_mult)
+ if block is None:
+ block = MBConv
+ super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block)
+
+ @staticmethod
+ def adjust_depth(num_layers: int, depth_mult: float):
+ return int(math.ceil(num_layers * depth_mult))
+
+
+class FusedMBConvConfig(_MBConvConfig):
+ # Stores information listed at Table 4 of the EfficientNetV2 paper
+ def __init__(
+ self,
+ expand_ratio: float,
+ kernel: int,
+ stride: int,
+ input_channels: int,
+ out_channels: int,
+ num_layers: int,
+ block: Optional[Callable[..., nn.Module]] = None,
+ ) -> None:
+ if block is None:
+ block = FusedMBConv
+ super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block)
+
+
+class MBConv(nn.Module):
+ def __init__(
+ self,
+ cnf: MBConvConfig,
+ stochastic_depth_prob: float,
+ norm_layer: Callable[..., nn.Module],
+ se_layer: Callable[..., nn.Module] = SqueezeExcitation,
+ ) -> None:
+ super().__init__()
+
+ if not (1 <= cnf.stride <= 2):
+ raise ValueError("illegal stride value")
+
+ self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
+
+ layers: List[nn.Module] = []
+ activation_layer = nn.SiLU
+
+ # expand
+ expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
+ if expanded_channels != cnf.input_channels:
+ layers.append(
+ Conv2dNormActivation(
+ cnf.input_channels,
+ expanded_channels,
+ kernel_size=1,
+ norm_layer=norm_layer,
+ activation_layer=activation_layer,
+ )
+ )
+
+ # depthwise
+ layers.append(
+ Conv2dNormActivation(
+ expanded_channels,
+ expanded_channels,
+ kernel_size=cnf.kernel,
+ stride=cnf.stride,
+ groups=expanded_channels,
+ norm_layer=norm_layer,
+ activation_layer=activation_layer,
+ )
+ )
+
+ # squeeze and excitation
+ squeeze_channels = max(1, cnf.input_channels // 4)
+ layers.append(se_layer(expanded_channels, squeeze_channels, activation=partial(nn.SiLU, inplace=True)))
+
+ # project
+ layers.append(
+ Conv2dNormActivation(
+ expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
+ )
+ )
+
+ self.block = nn.Sequential(*layers)
+ self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
+ self.out_channels = cnf.out_channels
+
+ def forward(self, input: Tensor) -> Tensor:
+ result = self.block(input)
+ if self.use_res_connect:
+ result = self.stochastic_depth(result)
+ result += input
+ return result
+
+
+class FusedMBConv(nn.Module):
+ def __init__(
+ self,
+ cnf: FusedMBConvConfig,
+ stochastic_depth_prob: float,
+ norm_layer: Callable[..., nn.Module],
+ ) -> None:
+ super().__init__()
+
+ if not (1 <= cnf.stride <= 2):
+ raise ValueError("illegal stride value")
+
+ self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
+
+ layers: List[nn.Module] = []
+ activation_layer = nn.SiLU
+
+ expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
+ if expanded_channels != cnf.input_channels:
+ # fused expand
+ layers.append(
+ Conv2dNormActivation(
+ cnf.input_channels,
+ expanded_channels,
+ kernel_size=cnf.kernel,
+ stride=cnf.stride,
+ norm_layer=norm_layer,
+ activation_layer=activation_layer,
+ )
+ )
+
+ # project
+ layers.append(
+ Conv2dNormActivation(
+ expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
+ )
+ )
+ else:
+ layers.append(
+ Conv2dNormActivation(
+ cnf.input_channels,
+ cnf.out_channels,
+ kernel_size=cnf.kernel,
+ stride=cnf.stride,
+ norm_layer=norm_layer,
+ activation_layer=activation_layer,
+ )
+ )
+
+ self.block = nn.Sequential(*layers)
+ self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
+ self.out_channels = cnf.out_channels
+
+ def forward(self, input: Tensor) -> Tensor:
+ result = self.block(input)
+ if self.use_res_connect:
+ result = self.stochastic_depth(result)
+ result += input
+ return result
+
+
+class EfficientNet(nn.Module):
+ def __init__(
+ self,
+ inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
+ dropout: float,
+ stochastic_depth_prob: float = 0.2,
+ num_classes: int = 1000,
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
+ last_channel: Optional[int] = None,
+ **kwargs: Any,
+ ) -> None:
+ """
+ EfficientNet V1 and V2 main class
+
+ Args:
+ inverted_residual_setting (Sequence[Union[MBConvConfig, FusedMBConvConfig]]): Network structure
+ dropout (float): The droupout probability
+ stochastic_depth_prob (float): The stochastic depth probability
+ num_classes (int): Number of classes
+ norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
+ last_channel (int): The number of channels on the penultimate layer
+ """
+ super().__init__()
+ _log_api_usage_once(self)
+
+ if not inverted_residual_setting:
+ raise ValueError("The inverted_residual_setting should not be empty")
+ elif not (
+ isinstance(inverted_residual_setting, Sequence)
+ and all([isinstance(s, _MBConvConfig) for s in inverted_residual_setting])
+ ):
+ raise TypeError("The inverted_residual_setting should be List[MBConvConfig]")
+
+ if "block" in kwargs:
+ warnings.warn(
+ "The parameter 'block' is deprecated since 0.13 and will be removed 0.15. "
+ "Please pass this information on 'MBConvConfig.block' instead."
+ )
+ if kwargs["block"] is not None:
+ for s in inverted_residual_setting:
+ if isinstance(s, MBConvConfig):
+ s.block = kwargs["block"]
+
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+
+ layers: List[nn.Module] = []
+
+ # building first layer
+ firstconv_output_channels = inverted_residual_setting[0].input_channels
+ # layers.append(
+ # Conv2dNormActivation(
+ # 3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU
+ # )
+ # )
+ self.first_coonv = Conv2dNormActivation(
+ 3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU
+ )
+
+ # building inverted residual blocks
+ total_stage_blocks = sum(cnf.num_layers for cnf in inverted_residual_setting)
+ stage_block_id = 0
+ for cnf in inverted_residual_setting:
+ stage: List[nn.Module] = []
+ for _ in range(cnf.num_layers):
+ # copy to avoid modifications. shallow copy is enough
+ block_cnf = copy.copy(cnf)
+
+ # overwrite info if not the first conv in the stage
+ if stage:
+ block_cnf.input_channels = block_cnf.out_channels
+ block_cnf.stride = 1
+
+ # adjust stochastic depth probability based on the depth of the stage block
+ sd_prob = stochastic_depth_prob * float(stage_block_id) / total_stage_blocks
+
+ stage.append(block_cnf.block(block_cnf, sd_prob, norm_layer))
+ stage_block_id += 1
+
+ layers.append(nn.Sequential(*stage))
+
+ # building last several layers
+ lastconv_input_channels = inverted_residual_setting[-1].out_channels
+ lastconv_output_channels = last_channel if last_channel is not None else 4 * lastconv_input_channels
+ layers.append(
+ Conv2dNormActivation(
+ lastconv_input_channels,
+ lastconv_output_channels,
+ kernel_size=1,
+ norm_layer=norm_layer,
+ activation_layer=nn.SiLU,
+ )
+ )
+ # self.last_conv = Conv2dNormActivation(
+ # lastconv_input_channels,
+ # lastconv_output_channels,
+ # kernel_size=1,
+ # norm_layer=norm_layer,
+ # activation_layer=nn.SiLU,
+ # )
+
+ # self.features = nn.Sequential(*layers)
+ self.features = nn.ModuleList(layers)
+ # self.avgpool = nn.AdaptiveAvgPool2d(1)
+ # self.classifier = nn.Sequential(
+ # nn.Dropout(p=dropout, inplace=True),
+ # nn.Linear(lastconv_output_channels, num_classes),
+ # )
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out")
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Linear):
+ init_range = 1.0 / math.sqrt(m.out_features)
+ nn.init.uniform_(m.weight, -init_range, init_range)
+ nn.init.zeros_(m.bias)
+
+ def _forward_impl(self, x: Tensor):
+ x = self.first_coonv(x)
+ # x = self.features(x)
+ feats = []
+ for i, layer in enumerate(self.features):
+ x = layer(x)
+ if i in [1, 2, 4, 6]:
+ feats.append(x)
+
+ # x = self.avgpool(x)
+ # x = torch.flatten(x, 1)
+ #
+ # x = self.classifier(x)
+
+ return feats
+
+ def forward(self, x: Tensor):
+ return self._forward_impl(x)
+
+
+def _efficientnet(
+ inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
+ dropout: float,
+ last_channel: Optional[int],
+ weights: Optional[WeightsEnum],
+ progress: bool,
+ **kwargs: Any,
+) -> EfficientNet:
+ if weights is not None:
+ _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+
+ model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs)
+
+ if weights is not None:
+ ckpt1 = weights.get_state_dict(progress=progress)
+ ckpt2 = model.state_dict()
+ kl1 = list(ckpt1.keys())
+ for i, k in enumerate(list(ckpt2.keys())):
+ ckpt2[k] = ckpt1[kl1[i]]
+ msg = model.load_state_dict(ckpt2, strict=False)
+ print(f'Load EfficientNet: {msg}')
+ else:
+ print('No pretrained weight loaded!')
+ return model
+
+
+def _efficientnet_conf(
+ arch: str,
+ **kwargs: Any,
+) -> Tuple[Sequence[Union[MBConvConfig, FusedMBConvConfig]], Optional[int]]:
+ inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]]
+ if arch.startswith("efficientnet_b"):
+ bneck_conf = partial(MBConvConfig, width_mult=kwargs.pop("width_mult"), depth_mult=kwargs.pop("depth_mult"))
+ inverted_residual_setting = [
+ bneck_conf(1, 3, 1, 32, 16, 1),
+ bneck_conf(6, 3, 2, 16, 24, 2),
+ bneck_conf(6, 5, 2, 24, 40, 2),
+ bneck_conf(6, 3, 2, 40, 80, 3),
+ bneck_conf(6, 5, 1, 80, 112, 3),
+ bneck_conf(6, 5, 2, 112, 192, 4),
+ bneck_conf(6, 3, 1, 192, 320, 1),
+ ]
+ last_channel = None
+ elif arch.startswith("efficientnet_v2_s"):
+ inverted_residual_setting = [
+ FusedMBConvConfig(1, 3, 1, 24, 24, 2),
+ FusedMBConvConfig(4, 3, 2, 24, 48, 4),
+ FusedMBConvConfig(4, 3, 2, 48, 64, 4),
+ MBConvConfig(4, 3, 2, 64, 128, 6),
+ MBConvConfig(6, 3, 1, 128, 160, 9),
+ MBConvConfig(6, 3, 2, 160, 256, 15),
+ ]
+ last_channel = 1280
+ elif arch.startswith("efficientnet_v2_m"):
+ inverted_residual_setting = [
+ FusedMBConvConfig(1, 3, 1, 24, 24, 3),
+ FusedMBConvConfig(4, 3, 2, 24, 48, 5),
+ FusedMBConvConfig(4, 3, 2, 48, 80, 5),
+ MBConvConfig(4, 3, 2, 80, 160, 7),
+ MBConvConfig(6, 3, 1, 160, 176, 14),
+ MBConvConfig(6, 3, 2, 176, 304, 18),
+ MBConvConfig(6, 3, 1, 304, 512, 5),
+ ]
+ last_channel = 1280
+ elif arch.startswith("efficientnet_v2_l"):
+ inverted_residual_setting = [
+ FusedMBConvConfig(1, 3, 1, 32, 32, 4),
+ FusedMBConvConfig(4, 3, 2, 32, 64, 7),
+ FusedMBConvConfig(4, 3, 2, 64, 96, 7),
+ MBConvConfig(4, 3, 2, 96, 192, 10),
+ MBConvConfig(6, 3, 1, 192, 224, 19),
+ MBConvConfig(6, 3, 2, 224, 384, 25),
+ MBConvConfig(6, 3, 1, 384, 640, 7),
+ ]
+ last_channel = 1280
+ else:
+ raise ValueError(f"Unsupported model type {arch}")
+
+ return inverted_residual_setting, last_channel
+
+
+_COMMON_META: Dict[str, Any] = {
+ "categories": _IMAGENET_CATEGORIES,
+}
+
+
+_COMMON_META_V1 = {
+ **_COMMON_META,
+ "min_size": (1, 1),
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet-v1",
+}
+
+
+_COMMON_META_V2 = {
+ **_COMMON_META,
+ "min_size": (33, 33),
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet-v2",
+}
+
+
+class EfficientNet_B0_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ # Weights ported from https://github.com/rwightman/pytorch-image-models/
+ url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
+ transforms=partial(
+ ImageClassification, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC
+ ),
+ meta={
+ **_COMMON_META_V1,
+ "num_params": 5288548,
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 77.692,
+ "acc@5": 93.532,
+ }
+ },
+ "_docs": """These weights are ported from the original paper.""",
+ },
+ )
+ DEFAULT = IMAGENET1K_V1
+
+
+class EfficientNet_B1_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ # Weights ported from https://github.com/rwightman/pytorch-image-models/
+ url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
+ transforms=partial(
+ ImageClassification, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC
+ ),
+ meta={
+ **_COMMON_META_V1,
+ "num_params": 7794184,
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 78.642,
+ "acc@5": 94.186,
+ }
+ },
+ "_docs": """These weights are ported from the original paper.""",
+ },
+ )
+ IMAGENET1K_V2 = Weights(
+ url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth",
+ transforms=partial(
+ ImageClassification, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR
+ ),
+ meta={
+ **_COMMON_META_V1,
+ "num_params": 7794184,
+ "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-lr-wd-crop-tuning",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 79.838,
+ "acc@5": 94.934,
+ }
+ },
+ "_docs": """
+ These weights improve upon the results of the original paper by using a modified version of TorchVision's
+ `new training recipe
+ `_.
+ """,
+ },
+ )
+ DEFAULT = IMAGENET1K_V2
+
+
+class EfficientNet_B2_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ # Weights ported from https://github.com/rwightman/pytorch-image-models/
+ url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
+ transforms=partial(
+ ImageClassification, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC
+ ),
+ meta={
+ **_COMMON_META_V1,
+ "num_params": 9109994,
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 80.608,
+ "acc@5": 95.310,
+ }
+ },
+ "_docs": """These weights are ported from the original paper.""",
+ },
+ )
+ DEFAULT = IMAGENET1K_V1
+
+
+class EfficientNet_B3_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ # Weights ported from https://github.com/rwightman/pytorch-image-models/
+ url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
+ transforms=partial(
+ ImageClassification, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC
+ ),
+ meta={
+ **_COMMON_META_V1,
+ "num_params": 12233232,
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 82.008,
+ "acc@5": 96.054,
+ }
+ },
+ "_docs": """These weights are ported from the original paper.""",
+ },
+ )
+ DEFAULT = IMAGENET1K_V1
+
+
+class EfficientNet_B4_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ # Weights ported from https://github.com/rwightman/pytorch-image-models/
+ url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
+ transforms=partial(
+ ImageClassification, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC
+ ),
+ meta={
+ **_COMMON_META_V1,
+ "num_params": 19341616,
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 83.384,
+ "acc@5": 96.594,
+ }
+ },
+ "_docs": """These weights are ported from the original paper.""",
+ },
+ )
+ DEFAULT = IMAGENET1K_V1
+
+
+class EfficientNet_B5_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
+ url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
+ transforms=partial(
+ ImageClassification, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC
+ ),
+ meta={
+ **_COMMON_META_V1,
+ "num_params": 30389784,
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 83.444,
+ "acc@5": 96.628,
+ }
+ },
+ "_docs": """These weights are ported from the original paper.""",
+ },
+ )
+ DEFAULT = IMAGENET1K_V1
+
+
+class EfficientNet_B6_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
+ url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
+ transforms=partial(
+ ImageClassification, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC
+ ),
+ meta={
+ **_COMMON_META_V1,
+ "num_params": 43040704,
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 84.008,
+ "acc@5": 96.916,
+ }
+ },
+ "_docs": """These weights are ported from the original paper.""",
+ },
+ )
+ DEFAULT = IMAGENET1K_V1
+
+
+class EfficientNet_B7_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
+ url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
+ transforms=partial(
+ ImageClassification, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC
+ ),
+ meta={
+ **_COMMON_META_V1,
+ "num_params": 66347960,
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 84.122,
+ "acc@5": 96.908,
+ }
+ },
+ "_docs": """These weights are ported from the original paper.""",
+ },
+ )
+ DEFAULT = IMAGENET1K_V1
+
+
+class EfficientNet_V2_S_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth",
+ transforms=partial(
+ ImageClassification,
+ crop_size=384,
+ resize_size=384,
+ interpolation=InterpolationMode.BILINEAR,
+ ),
+ meta={
+ **_COMMON_META_V2,
+ "num_params": 21458488,
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 84.228,
+ "acc@5": 96.878,
+ }
+ },
+ "_docs": """
+ These weights improve upon the results of the original paper by using a modified version of TorchVision's
+ `new training recipe
+ `_.
+ """,
+ },
+ )
+ DEFAULT = IMAGENET1K_V1
+
+
+class EfficientNet_V2_M_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth",
+ transforms=partial(
+ ImageClassification,
+ crop_size=480,
+ resize_size=480,
+ interpolation=InterpolationMode.BILINEAR,
+ ),
+ meta={
+ **_COMMON_META_V2,
+ "num_params": 54139356,
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 85.112,
+ "acc@5": 97.156,
+ }
+ },
+ "_docs": """
+ These weights improve upon the results of the original paper by using a modified version of TorchVision's
+ `new training recipe
+ `_.
+ """,
+ },
+ )
+ DEFAULT = IMAGENET1K_V1
+
+
+class EfficientNet_V2_L_Weights(WeightsEnum):
+ # Weights ported from https://github.com/google/automl/tree/master/efficientnetv2
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth",
+ transforms=partial(
+ ImageClassification,
+ crop_size=480,
+ resize_size=480,
+ interpolation=InterpolationMode.BICUBIC,
+ mean=(0.5, 0.5, 0.5),
+ std=(0.5, 0.5, 0.5),
+ ),
+ meta={
+ **_COMMON_META_V2,
+ "num_params": 118515272,
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 85.808,
+ "acc@5": 97.788,
+ }
+ },
+ "_docs": """These weights are ported from the original paper.""",
+ },
+ )
+ DEFAULT = IMAGENET1K_V1
+
+
+@handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1))
+def efficientnet_b0(
+ *, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+ """EfficientNet B0 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
+ Neural Networks `_ paper.
+
+ Args:
+ weights (:class:`~torchvision.models.EfficientNet_B0_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.EfficientNet_B0_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+ .. autoclass:: torchvision.models.EfficientNet_B0_Weights
+ :members:
+ """
+ weights = EfficientNet_B0_Weights.verify(weights)
+
+ inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0)
+ return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs)
+
+
+@handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.IMAGENET1K_V1))
+def efficientnet_b1(
+ *, weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+ """EfficientNet B1 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
+ Neural Networks `_ paper.
+
+ Args:
+ weights (:class:`~torchvision.models.EfficientNet_B1_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.EfficientNet_B1_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+ .. autoclass:: torchvision.models.EfficientNet_B1_Weights
+ :members:
+ """
+ weights = EfficientNet_B1_Weights.verify(weights)
+
+ inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b1", width_mult=1.0, depth_mult=1.1)
+ return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs)
+
+
+@handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.IMAGENET1K_V1))
+def efficientnet_b2(
+ *, weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+ """EfficientNet B2 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
+ Neural Networks `_ paper.
+
+ Args:
+ weights (:class:`~torchvision.models.EfficientNet_B2_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.EfficientNet_B2_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+ .. autoclass:: torchvision.models.EfficientNet_B2_Weights
+ :members:
+ """
+ weights = EfficientNet_B2_Weights.verify(weights)
+
+ inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b2", width_mult=1.1, depth_mult=1.2)
+ return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs)
+
+
+@handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.IMAGENET1K_V1))
+def efficientnet_b3(
+ *, weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+ """EfficientNet B3 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
+ Neural Networks `_ paper.
+
+ Args:
+ weights (:class:`~torchvision.models.EfficientNet_B3_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.EfficientNet_B3_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+ .. autoclass:: torchvision.models.EfficientNet_B3_Weights
+ :members:
+ """
+ weights = EfficientNet_B3_Weights.verify(weights)
+
+ inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b3", width_mult=1.2, depth_mult=1.4)
+ return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs)
+
+
+@handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.IMAGENET1K_V1))
+def efficientnet_b4(
+ *, weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+ """EfficientNet B4 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
+ Neural Networks `_ paper.
+
+ Args:
+ weights (:class:`~torchvision.models.EfficientNet_B4_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.EfficientNet_B4_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+ .. autoclass:: torchvision.models.EfficientNet_B4_Weights
+ :members:
+ """
+ weights = EfficientNet_B4_Weights.verify(weights)
+
+ inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b4", width_mult=1.4, depth_mult=1.8)
+ return _efficientnet(inverted_residual_setting, 0.4, last_channel, weights, progress, **kwargs)
+
+
+@handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.IMAGENET1K_V1))
+def efficientnet_b5(
+ *, weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+ """EfficientNet B5 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
+ Neural Networks `_ paper.
+
+ Args:
+ weights (:class:`~torchvision.models.EfficientNet_B5_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.EfficientNet_B5_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+ .. autoclass:: torchvision.models.EfficientNet_B5_Weights
+ :members:
+ """
+ weights = EfficientNet_B5_Weights.verify(weights)
+
+ inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b5", width_mult=1.6, depth_mult=2.2)
+ return _efficientnet(
+ inverted_residual_setting,
+ 0.4,
+ last_channel,
+ weights,
+ progress,
+ norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
+ **kwargs,
+ )
+
+
+@handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.IMAGENET1K_V1))
+def efficientnet_b6(
+ *, weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+ """EfficientNet B6 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
+ Neural Networks `_ paper.
+
+ Args:
+ weights (:class:`~torchvision.models.EfficientNet_B6_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.EfficientNet_B6_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+ .. autoclass:: torchvision.models.EfficientNet_B6_Weights
+ :members:
+ """
+ weights = EfficientNet_B6_Weights.verify(weights)
+
+ inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b6", width_mult=1.8, depth_mult=2.6)
+ return _efficientnet(
+ inverted_residual_setting,
+ 0.5,
+ last_channel,
+ weights,
+ progress,
+ norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
+ **kwargs,
+ )
+
+
+@handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.IMAGENET1K_V1))
+def efficientnet_b7(
+ *, weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+ """EfficientNet B7 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
+ Neural Networks `_ paper.
+
+ Args:
+ weights (:class:`~torchvision.models.EfficientNet_B7_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.EfficientNet_B7_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+ .. autoclass:: torchvision.models.EfficientNet_B7_Weights
+ :members:
+ """
+ weights = EfficientNet_B7_Weights.verify(weights)
+
+ inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b7", width_mult=2.0, depth_mult=3.1)
+ return _efficientnet(
+ inverted_residual_setting,
+ 0.5,
+ last_channel,
+ weights,
+ progress,
+ norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
+ **kwargs,
+ )
+
+
+@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1))
+def efficientnet_v2_s(
+ *, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+ """
+ Constructs an EfficientNetV2-S architecture from
+ `EfficientNetV2: Smaller Models and Faster Training `_.
+
+ Args:
+ weights (:class:`~torchvision.models.EfficientNet_V2_S_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.EfficientNet_V2_S_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+ .. autoclass:: torchvision.models.EfficientNet_V2_S_Weights
+ :members:
+ """
+ weights = EfficientNet_V2_S_Weights.verify(weights)
+
+ inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_s")
+ return _efficientnet(
+ inverted_residual_setting,
+ 0.2,
+ last_channel,
+ weights,
+ progress,
+ norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
+ **kwargs,
+ )
+
+
+@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_M_Weights.IMAGENET1K_V1))
+def efficientnet_v2_m(
+ *, weights: Optional[EfficientNet_V2_M_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+ """
+ Constructs an EfficientNetV2-M architecture from
+ `EfficientNetV2: Smaller Models and Faster Training `_.
+
+ Args:
+ weights (:class:`~torchvision.models.EfficientNet_V2_M_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.EfficientNet_V2_M_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+ .. autoclass:: torchvision.models.EfficientNet_V2_M_Weights
+ :members:
+ """
+ weights = EfficientNet_V2_M_Weights.verify(weights)
+
+ inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_m")
+ return _efficientnet(
+ inverted_residual_setting,
+ 0.3,
+ last_channel,
+ weights,
+ progress,
+ norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
+ **kwargs,
+ )
+
+
+@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_L_Weights.IMAGENET1K_V1))
+def efficientnet_v2_l(
+ *, weights: Optional[EfficientNet_V2_L_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+ """
+ Constructs an EfficientNetV2-L architecture from
+ `EfficientNetV2: Smaller Models and Faster Training `_.
+
+ Args:
+ weights (:class:`~torchvision.models.EfficientNet_V2_L_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.EfficientNet_V2_L_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+ .. autoclass:: torchvision.models.EfficientNet_V2_L_Weights
+ :members:
+ """
+ weights = EfficientNet_V2_L_Weights.verify(weights)
+
+ inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_l")
+ return _efficientnet(
+ inverted_residual_setting,
+ 0.4,
+ last_channel,
+ weights,
+ progress,
+ norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
+ **kwargs,
+ )
+
+
+# The dictionary below is internal implementation detail and will be removed in v0.15
+from torchvision.models._utils import _ModelURLs
+
+
+model_urls = _ModelURLs(
+ {
+ "efficientnet_b0": EfficientNet_B0_Weights.IMAGENET1K_V1.url,
+ "efficientnet_b1": EfficientNet_B1_Weights.IMAGENET1K_V1.url,
+ "efficientnet_b2": EfficientNet_B2_Weights.IMAGENET1K_V1.url,
+ "efficientnet_b3": EfficientNet_B3_Weights.IMAGENET1K_V1.url,
+ "efficientnet_b4": EfficientNet_B4_Weights.IMAGENET1K_V1.url,
+ "efficientnet_b5": EfficientNet_B5_Weights.IMAGENET1K_V1.url,
+ "efficientnet_b6": EfficientNet_B6_Weights.IMAGENET1K_V1.url,
+ "efficientnet_b7": EfficientNet_B7_Weights.IMAGENET1K_V1.url,
+ }
+)
diff --git a/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/ema.py b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..0267f9d92635e19d2f32e76d53ff1a22227eb025
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/ema.py
@@ -0,0 +1,191 @@
+import copy
+import torch
+from torch import nn
+
+
+def exists(val):
+ return val is not None
+
+
+def clamp(value, min_value=None, max_value=None):
+ assert exists(min_value) or exists(max_value)
+ if exists(min_value):
+ value = max(value, min_value)
+
+ if exists(max_value):
+ value = min(value, max_value)
+
+ return value
+
+
+class EMA(nn.Module):
+ """
+ Implements exponential moving average shadowing for your model.
+
+ Utilizes an inverse decay schedule to manage longer term training runs.
+ By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
+
+ @crowsonkb's notes on EMA Warmup:
+
+ If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
+ good values for models you plan to train for a million or more steps (reaches decay
+ factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
+ you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
+ 215.4k steps).
+
+ Args:
+ inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
+ power (float): Exponential factor of EMA warmup. Default: 1.
+ min_value (float): The minimum EMA decay rate. Default: 0.
+ """
+
+ def __init__(
+ self,
+ model,
+ ema_model=None,
+ # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
+ beta=0.9999,
+ update_after_step=100,
+ update_every=10,
+ inv_gamma=1.0,
+ power=2 / 3,
+ min_value=0.0,
+ param_or_buffer_names_no_ema=set(),
+ ignore_names=set(),
+ ignore_startswith_names=set(),
+ include_online_model=True
+ # set this to False if you do not wish for the online model to be saved along with the ema model (managed externally)
+ ):
+ super().__init__()
+ self.beta = beta
+
+ # whether to include the online model within the module tree, so that state_dict also saves it
+
+ self.include_online_model = include_online_model
+
+ if include_online_model:
+ self.online_model = model
+ else:
+ self.online_model = [model] # hack
+
+ # ema model
+
+ self.ema_model = ema_model
+
+ if not exists(self.ema_model):
+ try:
+ self.ema_model = copy.deepcopy(model)
+ except:
+ print('Your model was not copyable. Please make sure you are not using any LazyLinear')
+ exit()
+
+ self.ema_model.requires_grad_(False)
+
+ self.parameter_names = {name for name, param in self.ema_model.named_parameters() if param.dtype == torch.float}
+ self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if buffer.dtype == torch.float}
+
+ self.update_every = update_every
+ self.update_after_step = update_after_step
+
+ self.inv_gamma = inv_gamma
+ self.power = power
+ self.min_value = min_value
+
+ assert isinstance(param_or_buffer_names_no_ema, (set, list))
+ self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema # parameter or buffer
+
+ self.ignore_names = ignore_names
+ self.ignore_startswith_names = ignore_startswith_names
+
+ self.register_buffer('initted', torch.Tensor([False]))
+ self.register_buffer('step', torch.tensor([0]))
+
+ @property
+ def model(self):
+ return self.online_model if self.include_online_model else self.online_model[0]
+
+ def restore_ema_model_device(self):
+ device = self.initted.device
+ self.ema_model.to(device)
+
+ def get_params_iter(self, model):
+ for name, param in model.named_parameters():
+ if name not in self.parameter_names:
+ continue
+ yield name, param
+
+ def get_buffers_iter(self, model):
+ for name, buffer in model.named_buffers():
+ if name not in self.buffer_names:
+ continue
+ yield name, buffer
+
+ def copy_params_from_model_to_ema(self):
+ for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model),
+ self.get_params_iter(self.model)):
+ ma_params.data.copy_(current_params.data)
+
+ for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model),
+ self.get_buffers_iter(self.model)):
+ ma_buffers.data.copy_(current_buffers.data)
+
+ def get_current_decay(self):
+ epoch = clamp(self.step.item() - self.update_after_step - 1, min_value=0.)
+ value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
+
+ if epoch <= 0:
+ return 0.
+
+ return clamp(value, min_value=self.min_value, max_value=self.beta)
+
+ def update(self):
+ step = self.step.item()
+ self.step += 1
+
+ if (step % self.update_every) != 0:
+ return
+
+ if step <= self.update_after_step:
+ self.copy_params_from_model_to_ema()
+ return
+
+ if not self.initted.item():
+ self.copy_params_from_model_to_ema()
+ self.initted.data.copy_(torch.Tensor([True]))
+
+ self.update_moving_average(self.ema_model, self.model)
+
+ @torch.no_grad()
+ def update_moving_average(self, ma_model, current_model):
+ current_decay = self.get_current_decay()
+
+ for (name, current_params), (_, ma_params) in zip(self.get_params_iter(current_model),
+ self.get_params_iter(ma_model)):
+ if name in self.ignore_names:
+ continue
+
+ if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
+ continue
+
+ if name in self.param_or_buffer_names_no_ema:
+ ma_params.data.copy_(current_params.data)
+ continue
+
+ ma_params.data.lerp_(current_params.data, 1. - current_decay)
+
+ for (name, current_buffer), (_, ma_buffer) in zip(self.get_buffers_iter(current_model),
+ self.get_buffers_iter(ma_model)):
+ if name in self.ignore_names:
+ continue
+
+ if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
+ continue
+
+ if name in self.param_or_buffer_names_no_ema:
+ ma_buffer.data.copy_(current_buffer.data)
+ continue
+
+ ma_buffer.data.lerp_(current_buffer.data, 1. - current_decay)
+
+ def __call__(self, *args, **kwargs):
+ return self.ema_model(*args, **kwargs)
diff --git a/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/encoder_decoder.py b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d35e8da2cb64e2512ddb4031e120c8b7aed78832
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/encoder_decoder.py
@@ -0,0 +1,1086 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+from .loss import LPIPSWithDiscriminator
+
+# from ldm.util import instantiate_from_config
+# from ldm.modules.attention import LinearAttention
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
+ k = k.softmax(dim=-1)
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
+ return self.to_out(out)
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
+ From Fairseq.
+ Build sinusoidal embeddings.
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
+ return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0,1,0,1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+ dropout, temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels,
+ out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x+h
+
+
+class LinAttnBlock(LinearAttention):
+ """to match AttnBlock usage"""
+ def __init__(self, in_channels):
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = q.reshape(b,c,h*w)
+ q = q.permute(0,2,1) # b,hw,c
+ k = k.reshape(b,c,h*w) # b,c,hw
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b,c,h*w)
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b,c,h,w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+def make_attn(in_channels, attn_type="vanilla"):
+ assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ if attn_type == "vanilla":
+ return AttnBlock(in_channels)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ return LinAttnBlock(in_channels)
+
+
+class Model(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = self.ch*4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch,
+ self.temb_ch),
+ torch.nn.Linear(self.temb_ch,
+ self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ skip_in = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch*in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x, t=None, context=None):
+ #assert x.shape[2] == x.shape[3] == self.resolution
+ if context is not None:
+ # assume aligned context, cat along channel axis
+ x = torch.cat((x, context), dim=1)
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
+ **ignore_kwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = (curr_res[0] // 2, curr_res[1] // 2)
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ 2*z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+ attn_type="vanilla", **ignorekwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,)+tuple(ch_mult)
+ block_in = ch*ch_mult[self.num_resolutions-1]
+ curr_res = (resolution[0] // 2**(self.num_resolutions-1), resolution[1] // 2**(self.num_resolutions-1))
+ self.z_shape = (1,z_channels,curr_res[0],curr_res[1])
+ print("Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = (curr_res[0] * 2, curr_res[1] * 2)
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, z):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
+
+
+class SimpleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ super().__init__()
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
+ ResnetBlock(in_channels=in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=2 * in_channels,
+ out_channels=4 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=4 * in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ nn.Conv2d(2*in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True)])
+ # end
+ self.norm_out = Normalize(in_channels)
+ self.conv_out = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ for i, layer in enumerate(self.model):
+ if i in [1,2,3]:
+ x = layer(x, None)
+ else:
+ x = layer(x)
+
+ h = self.norm_out(x)
+ h = nonlinearity(h)
+ x = self.conv_out(h)
+ return x
+
+
+class UpsampleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+ ch_mult=(2,2), dropout=0.0):
+ super().__init__()
+ # upsampling
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ block_in = in_channels
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.res_blocks = nn.ModuleList()
+ self.upsample_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ res_block = []
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ res_block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ self.res_blocks.append(nn.ModuleList(res_block))
+ if i_level != self.num_resolutions - 1:
+ self.upsample_blocks.append(Upsample(block_in, True))
+ curr_res = curr_res * 2
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # upsampling
+ h = x
+ for k, i_level in enumerate(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.res_blocks[i_level][i_block](h, None)
+ if i_level != self.num_resolutions - 1:
+ h = self.upsample_blocks[k](h)
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class LatentRescaler(nn.Module):
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
+ super().__init__()
+ # residual block, interpolate, residual block
+ self.factor = factor
+ self.conv_in = nn.Conv2d(in_channels,
+ mid_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+ self.attn = AttnBlock(mid_channels)
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+
+ self.conv_out = nn.Conv2d(mid_channels,
+ out_channels,
+ kernel_size=1,
+ )
+
+ def forward(self, x):
+ x = self.conv_in(x)
+ for block in self.res_block1:
+ x = block(x, None)
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
+ x = self.attn(x)
+ for block in self.res_block2:
+ x = block(x, None)
+ x = self.conv_out(x)
+ return x
+
+
+class MergedRescaleEncoder(nn.Module):
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ intermediate_chn = ch * ch_mult[-1]
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
+ out_ch=None)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.encoder(x)
+ x = self.rescaler(x)
+ return x
+
+
+class MergedRescaleDecoder(nn.Module):
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ tmp_chn = z_channels*ch_mult[-1]
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
+ out_channels=tmp_chn, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Upsampler(nn.Module):
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+ super().__init__()
+ assert out_size >= in_size
+ num_blocks = int(np.log2(out_size//in_size))+1
+ factor_up = 1.+ (out_size % in_size)
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
+ out_channels=in_channels)
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
+ attn_resolutions=[], in_channels=None, ch=in_channels,
+ ch_mult=[ch_mult for _ in range(num_blocks)])
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Resize(nn.Module):
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+ super().__init__()
+ self.with_conv = learned
+ self.mode = mode
+ if self.with_conv:
+ print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
+ raise NotImplementedError()
+ assert in_channels is not None
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1)
+
+ def forward(self, x, scale_factor=1.0):
+ if scale_factor==1.0:
+ return x
+ else:
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
+ return x
+
+class FirstStagePostProcessor(nn.Module):
+
+ def __init__(self, ch_mult:list, in_channels,
+ pretrained_model:nn.Module=None,
+ reshape=False,
+ n_channels=None,
+ dropout=0.,
+ pretrained_config=None):
+ super().__init__()
+ if pretrained_config is None:
+ assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
+ self.pretrained_model = pretrained_model
+ else:
+ assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
+ self.instantiate_pretrained(pretrained_config)
+
+ self.do_reshape = reshape
+
+ if n_channels is None:
+ n_channels = self.pretrained_model.encoder.ch
+
+ self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
+ self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
+ stride=1,padding=1)
+
+ blocks = []
+ downs = []
+ ch_in = n_channels
+ for m in ch_mult:
+ blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
+ ch_in = m * n_channels
+ downs.append(Downsample(ch_in, with_conv=False))
+
+ self.model = nn.ModuleList(blocks)
+ self.downsampler = nn.ModuleList(downs)
+
+
+ def instantiate_pretrained(self, config):
+ model = instantiate_from_config(config)
+ self.pretrained_model = model.eval()
+ # self.pretrained_model.train = False
+ for param in self.pretrained_model.parameters():
+ param.requires_grad = False
+
+
+ @torch.no_grad()
+ def encode_with_pretrained(self,x):
+ c = self.pretrained_model.encode(x)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ return c
+
+ def forward(self,x):
+ z_fs = self.encode_with_pretrained(x)
+ z = self.proj_norm(z_fs)
+ z = self.proj(z)
+ z = nonlinearity(z)
+
+ for submodel, downmodel in zip(self.model,self.downsampler):
+ z = submodel(z,temb=None)
+ z = downmodel(z)
+
+ if self.do_reshape:
+ z = rearrange(z,'b c h w -> b (h w) c')
+ return z
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
+ dim=[1, 2, 3])
+
+ def nll(self, sample, dims=[1,2,3]):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims)
+
+ def mode(self):
+ return self.mean
+
+class AutoencoderKL(nn.Module):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ ):
+ super().__init__()
+ self.image_key = image_key
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ self.down_ratio = 2 ** (len(ddconfig['ch_mult']) - 1)
+ self.loss = LPIPSWithDiscriminator(**lossconfig)
+ assert ddconfig["double_z"]
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ self.embed_dim = embed_dim
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels)==int
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+ def init_from_ckpt(self, path, ignore_keys=list(), use_ema=True):
+ sd = torch.load(path, map_location="cpu")
+ sd_keys = sd.keys()
+ if 'ema' in list(sd.keys()) and use_ema:
+ sd = sd['ema']
+ new_sd = {}
+ for k in sd.keys():
+ if k.startswith("ema_model."):
+ new_k = k[10:] # remove ema_model.
+ new_sd[new_k] = sd[k]
+ sd = new_sd
+ else:
+ if 'model' in sd_keys:
+ sd = sd["model"]
+ elif 'state_dict' in sd_keys:
+ sd = sd['state_dict']
+ else:
+ sd = sd
+ # raise ValueError("")
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ msg = self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path}")
+ print('==>Load AutoEncoder Info: ', msg)
+
+ def encode(self, x):
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def training_step(self, inputs, optimizer_idx, global_step):
+ # inputs = self.get_input(batch, self.image_key)
+ reconstructions, posterior = self(inputs)
+
+ if optimizer_idx == 0:
+ # train encoder+decoder+logvar
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, global_step,
+ last_layer=self.get_last_layer(), split="train")
+ # self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ # self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return aeloss, log_dict_ae
+
+ if optimizer_idx == 1:
+ # train the discriminator
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, global_step,
+ last_layer=self.get_last_layer(), split="train")
+
+ # self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ # self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return discloss, log_dict_disc
+
+ def validation_step(self, inputs, global_step):
+ # inputs = self.get_input(batch, self.image_key)
+ reconstructions, posterior = self(inputs)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, global_step,
+ last_layer=self.get_last_layer(), split="val")
+
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, global_step,
+ last_layer=self.get_last_layer(), split="val")
+
+ # self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
+ # self.log_dict(log_dict_ae)
+ # self.log_dict(log_dict_disc)
+ return log_dict_ae, log_dict_disc
+
+ def validate_img(self, inputs):
+ reconstructions, posterior = self(inputs)
+ return reconstructions
+
+ # def configure_optimizers(self):
+ # lr = self.learning_rate
+ # opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+ # list(self.decoder.parameters())+
+ # list(self.quant_conv.parameters())+
+ # list(self.post_quant_conv.parameters()),
+ # lr=lr, betas=(0.5, 0.9))
+ # opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ # lr=lr, betas=(0.5, 0.9))
+ # return [opt_ae, opt_disc], []
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+ '''
+ @torch.no_grad()
+ def log_images(self, batch, only_inputs=False, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ if not only_inputs:
+ xrec, posterior = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
+ log["reconstructions"] = xrec
+ log["inputs"] = x
+ return log
+ def to_rgb(self, x):
+ assert self.image_key == "segmentation"
+ if not hasattr(self, "colorize"):
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+ x = nn.functional.conv2d(x, weight=self.colorize)
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+ return x
+ '''
+
+if __name__ == '__main__':
+ ddconfig = {'double_z': True,
+ 'z_channels': 4,
+ 'resolution': (240, 960),
+ 'in_channels': 3,
+ 'out_ch': 3,
+ 'ch': 128,
+ 'ch_mult': [ 1,2,4 ], # num_down = len(ch_mult)-1
+ 'num_res_blocks': 2,
+ 'attn_resolutions': [ ],
+ 'dropout': 0.0}
+ lossconfig = {'disc_start': 50001,
+ 'kl_weight': 0.000001,
+ 'disc_weight': 0.5}
+ model = AutoencoderKL(ddconfig, lossconfig, embed_dim=4,
+ ckpt_path='/pretrain_weights/model-kl-f8.ckpt', )
+ '''
+ from torch.optim import AdamW
+ optimizer = AdamW(model.parameters(), lr=0.01)
+ lr_lambda = lambda iter: (1 - iter / 1000) ** 0.95
+ lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
+ for s in range(1000):
+ lr_scheduler.step()
+ cur_lr = optimizer.param_groups[0]['lr']
+ print(cur_lr)
+ '''
+ x = torch.rand(1, 3, 240, 960)
+ with torch.no_grad():
+ y = model(x)
+ pass
diff --git a/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/imagenet.py b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/imagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..27ff9a952b30ba4bdfd3df5a6c7c38a7115327bc
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/imagenet.py
@@ -0,0 +1,395 @@
+import os, yaml, pickle, shutil, tarfile, glob
+import cv2
+import albumentations
+import PIL
+import numpy as np
+import torchvision.transforms.functional as TF
+# from omegaconf import OmegaConf
+from functools import partial
+from PIL import Image
+from tqdm import tqdm
+from torch.utils.data import Dataset, Subset
+
+import taming.data.utils as tdu
+from custom_controlnet_aux.diffusion_edge.taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
+from custom_controlnet_aux.diffusion_edge.taming.data.imagenet import ImagePaths
+
+# from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
+
+
+def synset2idx(path_to_yaml="data/index_synset.yaml"):
+ with open(path_to_yaml) as f:
+ di2s = yaml.load(f)
+ return dict((v,k) for k,v in di2s.items())
+
+
+class ImageNetBase(Dataset):
+ def __init__(self, config=None):
+ self.config = config
+ # if not type(self.config)==dict:
+ # self.config = OmegaConf.to_container(self.config)
+ self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
+ self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
+ self._prepare()
+ self._prepare_synset_to_human()
+ self._prepare_idx_to_synset()
+ self._prepare_human_to_integer_label()
+ self._load()
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ return self.data[i]
+
+ def _prepare(self):
+ raise NotImplementedError()
+
+ def _filter_relpaths(self, relpaths):
+ ignore = set([
+ "n06596364_9591.JPEG",
+ ])
+ relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
+ if "sub_indices" in self.config:
+ indices = str_to_indices(self.config["sub_indices"])
+ synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
+ self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
+ files = []
+ for rpath in relpaths:
+ syn = rpath.split("/")[0]
+ if syn in synsets:
+ files.append(rpath)
+ return files
+ else:
+ return relpaths
+
+ def _prepare_synset_to_human(self):
+ SIZE = 2655750
+ URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
+ self.human_dict = os.path.join(self.root, "synset_human.txt")
+ if (not os.path.exists(self.human_dict) or
+ not os.path.getsize(self.human_dict)==SIZE):
+ download(URL, self.human_dict)
+
+ def _prepare_idx_to_synset(self):
+ URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
+ self.idx2syn = os.path.join(self.root, "index_synset.yaml")
+ if (not os.path.exists(self.idx2syn)):
+ download(URL, self.idx2syn)
+
+ def _prepare_human_to_integer_label(self):
+ URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
+ self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
+ if (not os.path.exists(self.human2integer)):
+ download(URL, self.human2integer)
+ with open(self.human2integer, "r") as f:
+ lines = f.read().splitlines()
+ assert len(lines) == 1000
+ self.human2integer_dict = dict()
+ for line in lines:
+ value, key = line.split(":")
+ self.human2integer_dict[key] = int(value)
+
+ def _load(self):
+ with open(self.txt_filelist, "r") as f:
+ self.relpaths = f.read().splitlines()
+ l1 = len(self.relpaths)
+ self.relpaths = self._filter_relpaths(self.relpaths)
+ print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
+
+ self.synsets = [p.split("/")[0] for p in self.relpaths]
+ self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
+
+ unique_synsets = np.unique(self.synsets)
+ class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
+ if not self.keep_orig_class_label:
+ self.class_labels = [class_dict[s] for s in self.synsets]
+ else:
+ self.class_labels = [self.synset2idx[s] for s in self.synsets]
+
+ with open(self.human_dict, "r") as f:
+ human_dict = f.read().splitlines()
+ human_dict = dict(line.split(maxsplit=1) for line in human_dict)
+
+ self.human_labels = [human_dict[s] for s in self.synsets]
+
+ labels = {
+ "relpath": np.array(self.relpaths),
+ "synsets": np.array(self.synsets),
+ "class_label": np.array(self.class_labels),
+ "human_label": np.array(self.human_labels),
+ }
+
+ if self.process_images:
+ # self.size = retrieve(self.config, "size", default=256)
+ self.size = self.config.get("size", default=256)
+ self.data = ImagePaths(self.abspaths,
+ labels=labels,
+ size=self.size,
+ random_crop=self.random_crop,
+ )
+ else:
+ self.data = self.abspaths
+
+
+class ImageNetTrain(ImageNetBase):
+ NAME = "ILSVRC2012_train"
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
+ AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
+ FILES = [
+ "ILSVRC2012_img_train.tar",
+ ]
+ SIZES = [
+ 147897477120,
+ ]
+
+ def __init__(self, process_images=True, data_root=None, **kwargs):
+ self.process_images = process_images
+ self.data_root = data_root
+ super().__init__(**kwargs)
+
+ def _prepare(self):
+ if self.data_root:
+ self.root = os.path.join(self.data_root, self.NAME)
+ else:
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
+
+ self.datadir = os.path.join(self.root, "data")
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
+ self.expected_length = 1281167
+ # self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", default=True)
+ self.random_crop = self.config.get("random_crop", default=True)
+ if not tdu.is_prepared(self.root):
+ # prep
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
+
+ datadir = self.datadir
+ if not os.path.exists(datadir):
+ path = os.path.join(self.root, self.FILES[0])
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
+ import academictorrents as at
+ atpath = at.get(self.AT_HASH, datastore=self.root)
+ assert atpath == path
+
+ print("Extracting {} to {}".format(path, datadir))
+ os.makedirs(datadir, exist_ok=True)
+ with tarfile.open(path, "r:") as tar:
+ tar.extractall(path=datadir)
+
+ print("Extracting sub-tars.")
+ subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
+ for subpath in tqdm(subpaths):
+ subdir = subpath[:-len(".tar")]
+ os.makedirs(subdir, exist_ok=True)
+ with tarfile.open(subpath, "r:") as tar:
+ tar.extractall(path=subdir)
+
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
+ filelist = sorted(filelist)
+ filelist = "\n".join(filelist)+"\n"
+ with open(self.txt_filelist, "w") as f:
+ f.write(filelist)
+
+ tdu.mark_prepared(self.root)
+
+
+class ImageNetValidation(ImageNetBase):
+ NAME = "ILSVRC2012_validation"
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
+ AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
+ VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
+ FILES = [
+ "ILSVRC2012_img_val.tar",
+ "validation_synset.txt",
+ ]
+ SIZES = [
+ 6744924160,
+ 1950000,
+ ]
+
+ def __init__(self, process_images=True, data_root=None, **kwargs):
+ self.data_root = data_root
+ self.process_images = process_images
+ super().__init__(**kwargs)
+
+ def _prepare(self):
+ if self.data_root:
+ self.root = os.path.join(self.data_root, self.NAME)
+ else:
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
+ self.datadir = os.path.join(self.root, "data")
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
+ self.expected_length = 50000
+ # self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", default=False)
+ self.random_crop = self.config.get("random_crop", default=False)
+ if not tdu.is_prepared(self.root):
+ # prep
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
+
+ datadir = self.datadir
+ if not os.path.exists(datadir):
+ path = os.path.join(self.root, self.FILES[0])
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
+ import academictorrents as at
+ atpath = at.get(self.AT_HASH, datastore=self.root)
+ assert atpath == path
+
+ print("Extracting {} to {}".format(path, datadir))
+ os.makedirs(datadir, exist_ok=True)
+ with tarfile.open(path, "r:") as tar:
+ tar.extractall(path=datadir)
+
+ vspath = os.path.join(self.root, self.FILES[1])
+ if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
+ download(self.VS_URL, vspath)
+
+ with open(vspath, "r") as f:
+ synset_dict = f.read().splitlines()
+ synset_dict = dict(line.split() for line in synset_dict)
+
+ print("Reorganizing into synset folders")
+ synsets = np.unique(list(synset_dict.values()))
+ for s in synsets:
+ os.makedirs(os.path.join(datadir, s), exist_ok=True)
+ for k, v in synset_dict.items():
+ src = os.path.join(datadir, k)
+ dst = os.path.join(datadir, v)
+ shutil.move(src, dst)
+
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
+ filelist = sorted(filelist)
+ filelist = "\n".join(filelist)+"\n"
+ with open(self.txt_filelist, "w") as f:
+ f.write(filelist)
+
+ tdu.mark_prepared(self.root)
+
+
+
+class ImageNetSR(Dataset):
+ def __init__(self, size=None,
+ degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
+ random_crop=True):
+ """
+ Imagenet Superresolution Dataloader
+ Performs following ops in order:
+ 1. crops a crop of size s from image either as random or center crop
+ 2. resizes crop to size with cv2.area_interpolation
+ 3. degrades resized crop with degradation_fn
+
+ :param size: resizing to size after cropping
+ :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
+ :param downscale_f: Low Resolution Downsample factor
+ :param min_crop_f: determines crop size s,
+ where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
+ :param max_crop_f: ""
+ :param data_root:
+ :param random_crop:
+ """
+ self.base = self.get_base()
+ assert size
+ assert (size / downscale_f).is_integer()
+ self.size = size
+ self.LR_size = int(size / downscale_f)
+ self.min_crop_f = min_crop_f
+ self.max_crop_f = max_crop_f
+ assert(max_crop_f <= 1.)
+ self.center_crop = not random_crop
+
+ self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
+
+ self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
+
+ if degradation == "bsrgan":
+ self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
+
+ elif degradation == "bsrgan_light":
+ self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
+
+ else:
+ interpolation_fn = {
+ "cv_nearest": cv2.INTER_NEAREST,
+ "cv_bilinear": cv2.INTER_LINEAR,
+ "cv_bicubic": cv2.INTER_CUBIC,
+ "cv_area": cv2.INTER_AREA,
+ "cv_lanczos": cv2.INTER_LANCZOS4,
+ "pil_nearest": PIL.Image.NEAREST,
+ "pil_bilinear": PIL.Image.BILINEAR,
+ "pil_bicubic": PIL.Image.BICUBIC,
+ "pil_box": PIL.Image.BOX,
+ "pil_hamming": PIL.Image.HAMMING,
+ "pil_lanczos": PIL.Image.LANCZOS,
+ }[degradation]
+
+ self.pil_interpolation = degradation.startswith("pil_")
+
+ if self.pil_interpolation:
+ self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
+
+ else:
+ self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
+ interpolation=interpolation_fn)
+
+ def __len__(self):
+ return len(self.base)
+
+ def __getitem__(self, i):
+ example = self.base[i]
+ image = Image.open(example["file_path_"])
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ image = np.array(image).astype(np.uint8)
+
+ min_side_len = min(image.shape[:2])
+ crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
+ crop_side_len = int(crop_side_len)
+
+ if self.center_crop:
+ self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
+
+ else:
+ self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
+
+ image = self.cropper(image=image)["image"]
+ image = self.image_rescaler(image=image)["image"]
+
+ if self.pil_interpolation:
+ image_pil = PIL.Image.fromarray(image)
+ LR_image = self.degradation_process(image_pil)
+ LR_image = np.array(LR_image).astype(np.uint8)
+
+ else:
+ LR_image = self.degradation_process(image=image)["image"]
+
+ example["image"] = (image/127.5 - 1.0).astype(np.float32)
+ example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
+
+ return example
+
+
+class ImageNetSRTrain(ImageNetSR):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def get_base(self):
+ with open("data/imagenet_train_hr_indices.p", "rb") as f:
+ indices = pickle.load(f)
+ dset = ImageNetTrain(process_images=False,)
+ return Subset(dset, indices)
+
+
+class ImageNetSRValidation(ImageNetSR):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def get_base(self):
+ with open("data/imagenet_val_hr_indices.p", "rb") as f:
+ indices = pickle.load(f)
+ dset = ImageNetValidation(process_images=False,)
+ return Subset(dset, indices)
diff --git a/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/loss.py b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..07f00750858bd9597a3e312c75fd81d6d2c30c1d
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/loss.py
@@ -0,0 +1,113 @@
+import torch
+import torch.nn as nn
+import sys
+# .path.append()
+from custom_controlnet_aux.diffusion_edge.taming.modules.losses.vqperceptual import *
+
+
+class LPIPSWithDiscriminator(nn.Module):
+ def __init__(self, *, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+ disc_loss="hinge"):
+
+ super().__init__()
+ assert disc_loss in ["hinge", "vanilla"]
+ self.kl_weight = kl_weight
+ self.pixel_weight = pixelloss_weight
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+ # output log variance
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
+
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
+ n_layers=disc_num_layers,
+ use_actnorm=use_actnorm
+ ).apply(weights_init)
+ self.discriminator_iter_start = disc_start
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.disc_conditional = disc_conditional
+
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+ if last_layer is not None:
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+ else:
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
+
+ def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
+ global_step, last_layer=None, cond=None, split="train",
+ weights=None):
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + \
+ F.mse_loss(inputs, reconstructions, reduction="none")
+ if self.perceptual_weight > 0:
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
+ weighted_nll_loss = nll_loss
+ if weights is not None:
+ weighted_nll_loss = weights*nll_loss
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+ kl_loss = posteriors.kl()
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+
+ # now the GAN part
+ if optimizer_idx == 0:
+ # generator update
+ if cond is None:
+ assert not self.disc_conditional
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ else:
+ assert self.disc_conditional
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
+ g_loss = -torch.mean(logits_fake)
+
+ if self.disc_factor > 0.0:
+ try:
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
+ except RuntimeError:
+ assert not self.training
+ d_weight = torch.tensor(0.0)
+ else:
+ d_weight = torch.tensor(0.0)
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
+
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
+ "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
+ "{}/d_weight".format(split): d_weight.detach(),
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
+ "{}/g_loss".format(split): g_loss.detach().mean(),
+ }
+ return loss, log
+
+ if optimizer_idx == 1:
+ # second pass for discriminator update
+ if cond is None:
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+ else:
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+ log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+ "{}/logits_real".format(split): logits_real.detach().mean(),
+ "{}/logits_fake".format(split): logits_fake.detach().mean()
+ }
+ return d_loss, log
+
diff --git a/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/mask_cond_unet.py b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/mask_cond_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ac84c5d7788f049a2f8e82d35ab075f2065718d
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/mask_cond_unet.py
@@ -0,0 +1,1009 @@
+import fvcore.common.config
+import torch
+import torch.nn as nn
+import math
+import torch.nn.functional as F
+from functools import partial
+from einops import rearrange, reduce
+from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.efficientnet import efficientnet_b7, EfficientNet_B7_Weights
+from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.resnet import resnet101, ResNet101_Weights
+from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.swin_transformer import swin_b, Swin_B_Weights
+from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.vgg import vgg16, VGG16_Weights
+
+from custom_controlnet_aux.util import custom_torch_download
+# from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.wcc import fft
+### Compared to unet4:
+# 1. add FFT-Conv on the mid feature.
+######## Attention Layer ##########
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+ # self.class_token_pos = nn.Parameter(torch.zeros(1, 1, num_pos_feats * 2))
+ # self.class_token_pos
+
+ def forward(self, x):
+ # x: b, h, w, d
+ num_feats = x.shape[3]
+ num_pos_feats = num_feats // 2
+ # mask = tensor_list.mask
+ mask = torch.zeros(x.shape[0], x.shape[1], x.shape[2], device=x.device).to(torch.bool)
+ batch = mask.shape[0]
+ assert mask is not None
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-5
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ # pos = torch.cat((pos_y, pos_x), dim=3).flatten(1, 2)
+ pos = torch.cat((pos_y, pos_x), dim=3).contiguous()
+ '''
+ pos_x: b ,h, w, d//2
+ pos_y: b, h, w, d//2
+ pos: b, h, w, d
+ '''
+ return pos
+
+class PositionEmbeddingLearned(nn.Module):
+ """
+ Absolute pos embedding, learned.
+ """
+ def __init__(self, feature_size, num_pos_feats=256):
+ super().__init__()
+ self.row_embed = nn.Embedding(feature_size[0], num_pos_feats)
+ self.col_embed = nn.Embedding(feature_size[1], num_pos_feats)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.uniform_(self.row_embed.weight)
+ nn.init.uniform_(self.col_embed.weight)
+
+ def forward(self, x):
+ h, w = x.shape[-2:]
+ i = torch.arange(w, device=x.device)
+ j = torch.arange(h, device=x.device)
+ x_emb = self.col_embed(i)
+ y_emb = self.row_embed(j)
+ pos = torch.cat([
+ x_emb.unsqueeze(0).repeat(h, 1, 1),
+ y_emb.unsqueeze(1).repeat(1, w, 1),
+ ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
+ return torch.cat([x, pos], dim=1)
+
+
+class ChannelAttention(nn.Module):
+ def __init__(self, in_planes, ratio=8):
+ super(ChannelAttention, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.max_pool = nn.AdaptiveMaxPool2d(1)
+
+ self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
+ self.relu1 = nn.ReLU()
+ self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
+
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x):
+ avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
+ max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
+ out = avg_out + max_out
+ return self.sigmoid(out) * x
+
+class SpatialAtt(nn.Module):
+ def __init__(self, in_dim):
+ super(SpatialAtt, self).__init__()
+ self.map = nn.Conv2d(in_dim, 1, 1)
+ self.q_conv = nn.Conv2d(1, 1, 1)
+ self.k_conv = nn.Conv2d(1, 1, 1)
+ self.activation = nn.Softsign()
+
+ def forward(self, x):
+ b, _, h, w = x.shape
+ att = self.map(x) # b, 1, h, w
+ q = self.q_conv(att) # b, 1, h, w
+ q = rearrange(q, 'b c h w -> b (h w) c')
+ k = self.k_conv(att)
+ k = rearrange(k, 'b c h w -> b c (h w)')
+ att = rearrange(att, 'b c h w -> b (h w) c')
+ att = F.softmax(q @ k, dim=-1) @ att # b, hw, 1
+ att = att.reshape(b, 1, h, w)
+ return self.activation(att) * x
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ # self.fc1 = nn.Linear(in_features, hidden_features)
+ self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1)
+ self.act = act_layer()
+ self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1)
+ # self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+
+class BasicAttetnionLayer(nn.Module):
+ def __init__(self, embed_dim=128, nhead=8, ffn_dim=512, window_size1=[4, 4],
+ window_size2=[1, 1], dropout=0.1):
+ super().__init__()
+ self.window_size1 = window_size1
+ self.window_size2 = window_size2
+ self.avgpool_q = nn.AvgPool2d(kernel_size=window_size1)
+ self.avgpool_k = nn.AvgPool2d(kernel_size=window_size2)
+ self.softmax = nn.Softmax(dim=-1)
+ self.nhead = nhead
+
+ self.q_lin = nn.Linear(embed_dim, embed_dim)
+ self.k_lin = nn.Linear(embed_dim, embed_dim)
+ self.v_lin = nn.Linear(embed_dim, embed_dim)
+
+ self.mlp = Mlp(in_features=embed_dim, hidden_features=ffn_dim, drop=dropout)
+ self.pos_enc = PositionEmbeddingSine(embed_dim)
+ self.concat_conv = nn.Conv2d(2 * embed_dim, embed_dim, 1)
+ self.gn = nn.GroupNorm(8, embed_dim)
+
+ self.out_conv = nn.Conv2d(embed_dim, embed_dim, 1)
+ self.init_weights()
+
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0.)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1.)
+ nn.init.constant_(m.bias, 0.)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.weight, 1.)
+ nn.init.constant_(m.bias, 0.)
+ elif isinstance(m, nn.Linear):
+ nn.init.xavier_normal_(m.weight)
+ nn.init.constant_(m.bias, 0.)
+
+ def forward(self, x1, x2): # x1 for q (conditional input), x2 for k,v
+ B, C1, H1, W1 = x1.shape
+ _, C2, H2, W2 = x2.shape
+ # x1 = x1.permute(0, 2, 3, 1).contiguous() # B, H1, W1, C1
+ shortcut = x2 + self.concat_conv(torch.cat(
+ [F.interpolate(x1, size=(H2, W2), mode='bilinear', align_corners=True),
+ x2], dim=1))
+ shortcut = self.gn(shortcut)
+ pad_l = pad_t = 0
+ pad_r = (self.window_size1[1] - W1 % self.window_size1[1]) % self.window_size1[1]
+ pad_b = (self.window_size1[0] - H1 % self.window_size1[0]) % self.window_size1[0]
+ x1 = F.pad(x1, (pad_l, pad_r, pad_t, pad_b, 0, 0))
+ _, _, H1p, W1p = x1.shape
+ # x2 = x2.permute(0, 2, 3, 1).contiguous() # B, H2, W2, C2
+ pad_l = pad_t = 0
+ pad_r = (self.window_size2[1] - W2 % self.window_size2[1]) % self.window_size2[1]
+ pad_b = (self.window_size2[0] - H2 % self.window_size2[0]) % self.window_size2[0]
+ x2 = F.pad(x2, (pad_l, pad_r, pad_t, pad_b, 0, 0))
+ _, _, H2p, W2p = x2.shape
+ # x1g = x1 #B, C1, H1p, W1p
+ # x2g = x2 #B, C2, H2p, W2p
+ x1_s = self.avgpool_q(x1)
+ qg = self.avgpool_q(x1).permute(0, 2, 3, 1).contiguous()
+ qg = qg + self.pos_enc(qg)
+ qg= qg.view(B, -1, C2)
+ kg = self.avgpool_k(x2).permute(0, 2, 3, 1).contiguous()
+ kg = kg + self.pos_enc(kg)
+ kg = kg.view(B, -1, C1)
+ num_window_q = qg.shape[1]
+ num_window_k = kg.shape[1]
+ qg = self.q_lin(qg).reshape(B, num_window_q, self.nhead, C1 // self.nhead).permute(0, 2, 1,
+ 3).contiguous()
+ kg2 = self.k_lin(kg).reshape(B, num_window_k, self.nhead, C1 // self.nhead).permute(0, 2, 1,
+ 3).contiguous()
+ vg = self.v_lin(kg).reshape(B, num_window_k, self.nhead, C1 // self.nhead).permute(0, 2, 1,
+ 3).contiguous()
+ kg = kg2
+ attn = (qg @ kg.transpose(-2, -1))
+ attn = self.softmax(attn)
+ qg = (attn @ vg).transpose(1, 2).reshape(B, num_window_q, C1)
+ qg = qg.transpose(1, 2).reshape(B, C1, H1p // self.window_size1[0], W1p // self.window_size1[1])
+ # qg = F.interpolate(qg, size=(H1p, W1p), mode='bilinear', align_corners=False)
+ x1_s = x1_s + qg
+ x1_s = x1_s + self.mlp(x1_s)
+ x1_s = F.interpolate(x1_s, size=(H2, W2), mode='bilinear', align_corners=True)
+ x1_s = shortcut + self.out_conv(x1_s)
+ # x1_s = self.out_norm(x1_s)
+ return x1_s
+
+class RelationNet(nn.Module):
+ def __init__(self, in_channel1=128, in_channel2=128, nhead=8, layers=3, embed_dim=128, ffn_dim=512,
+ window_size1= [4, 4], window_size2=[1, 1]):
+ # self.attention = BasicAttetnionLayer(embed_dim=embed_dim, nhead=nhead, ffn_dim=ffn_dim,
+ # window_size1=window_size1, window_size2=window_size2, dropout=0.1)
+ super().__init__()
+ self.layers = layers
+ self.input_conv1 = nn.Sequential(
+ nn.Conv2d(in_channel1, embed_dim, 1),
+ nn.BatchNorm2d(embed_dim, momentum=0.03, eps=0.001),
+ )
+ self.input_conv2 = nn.Sequential(
+ nn.Conv2d(in_channel2, embed_dim, 1),
+ nn.BatchNorm2d(embed_dim, momentum=0.03, eps=0.001),
+ )
+ # self.input_conv1 = ConvModule(in_channel1,
+ # embed_dim,
+ # 1,
+ # norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
+ # act_cfg=None)
+ # self.input_conv2 = ConvModule(in_channel2,
+ # embed_dim,
+ # 1,
+ # norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
+ # act_cfg=None)
+ # self.input_conv2 = nn.Linear(in_channel2, embed_dim)
+ self.attentions = nn.ModuleList()
+ for i in range(layers):
+ self.attentions.append(
+ BasicAttetnionLayer(embed_dim=embed_dim, nhead=nhead, ffn_dim=ffn_dim,
+ window_size1=window_size1, window_size2=window_size2, dropout=0.1)
+ )
+
+ def forward(self, cond, feat):
+ # cluster = cluster.unsqueeze(0).repeat(feature.shape[0], 1, 1, 1)
+ cond = self.input_conv1(cond)
+ feat = self.input_conv2(feat)
+ for att in self.attentions:
+ feat = att(cond, feat)
+ return feat
+
+
+
+################# U-Net model defenition ####################
+
+def exists(x):
+ return x is not None
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if callable(d) else d
+
+def identity(t, *args, **kwargs):
+ return t
+
+def cycle(dl):
+ while True:
+ for data in dl:
+ yield data
+
+def has_int_squareroot(num):
+ return (math.sqrt(num) ** 2) == num
+
+def num_to_groups(num, divisor):
+ groups = num // divisor
+ remainder = num % divisor
+ arr = [divisor] * groups
+ if remainder > 0:
+ arr.append(remainder)
+ return arr
+
+def convert_image_to_fn(img_type, image):
+ if image.mode != img_type:
+ return image.convert(img_type)
+ return image
+
+# normalization functions
+
+def normalize_to_neg_one_to_one(img):
+ return img * 2 - 1
+
+def unnormalize_to_zero_to_one(t):
+ return (t + 1) * 0.5
+
+# small helper modules
+
+class Residual(nn.Module):
+ def __init__(self, fn):
+ super().__init__()
+ self.fn = fn
+
+ def forward(self, x, *args, **kwargs):
+ return self.fn(x, *args, **kwargs) + x
+
+def Upsample(dim, dim_out = None):
+ return nn.Sequential(
+ nn.Upsample(scale_factor = 2, mode = 'nearest'),
+ nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1)
+ )
+
+def Downsample(dim, dim_out = None):
+ return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1)
+
+class WeightStandardizedConv2d(nn.Conv2d):
+ """
+ https://arxiv.org/abs/1903.10520
+ weight standardization purportedly works synergistically with group normalization
+ """
+ def forward(self, x):
+ eps = 1e-5 if x.dtype == torch.float32 else 1e-3
+
+ weight = self.weight
+ mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
+ var = reduce(weight, 'o ... -> o 1 1 1', partial(torch.var, unbiased = False))
+ normalized_weight = (weight - mean) * (var + eps).rsqrt()
+
+ return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+class LayerNorm(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
+
+ def forward(self, x):
+ eps = 1e-5 if x.dtype == torch.float32 else 1e-3
+ var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
+ mean = torch.mean(x, dim = 1, keepdim = True)
+ return (x - mean) * (var + eps).rsqrt() * self.g
+
+class PreNorm(nn.Module):
+ def __init__(self, dim, fn):
+ super().__init__()
+ self.fn = fn
+ self.norm = LayerNorm(dim)
+
+ def forward(self, x):
+ x = self.norm(x)
+ return self.fn(x)
+
+# sinusoidal positional embeds
+
+class SinusoidalPosEmb(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, x):
+ device = x.device
+ half_dim = self.dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+ emb = x[:, None] * emb[None, :]
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+ return emb
+
+class GaussianFourierProjection(nn.Module):
+ """Gaussian Fourier embeddings for noise levels."""
+
+ def __init__(self, embedding_size=256, scale=1.0):
+ super().__init__()
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+
+ def forward(self, x):
+ x_proj = x[:, None] * self.W[None, :] * 2 * math.pi
+ return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+
+class RandomOrLearnedSinusoidalPosEmb(nn.Module):
+ """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """
+ """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
+
+ def __init__(self, dim, is_random = False):
+ super().__init__()
+ assert (dim % 2) == 0
+ half_dim = dim // 2
+ self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random)
+
+ def forward(self, x):
+ x = rearrange(x, 'b -> b 1')
+ freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
+ fouriered = torch.cat((x, fouriered), dim = -1)
+ return fouriered
+
+# building block modules
+
+class Block(nn.Module):
+ def __init__(self, dim, dim_out, groups = 8):
+ super().__init__()
+ self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1)
+ self.norm = nn.GroupNorm(groups, dim_out)
+ self.act = nn.SiLU()
+
+ def forward(self, x, scale_shift = None):
+ x = self.proj(x)
+ x = self.norm(x)
+
+ if exists(scale_shift):
+ scale, shift = scale_shift
+ x = x * (scale + 1) + shift
+
+ x = self.act(x)
+ return x
+
+class BlockFFT(nn.Module):
+ def __init__(self, dim, h, w, groups = 8):
+ super().__init__()
+ # self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1)
+ self.complex_weight = nn.Parameter(torch.randn(dim, h, w//2+1, 2, dtype=torch.float32) * 0.02)
+ # self.complex_weight = nn.Parameter(torch.normal(mean=0, std=0.01, size=(dim, h, w // 2 + 1, 2), dtype=torch.float32))
+ # self.norm = nn.GroupNorm(groups, dim)
+ # self.act = nn.SiLU()
+
+ def forward(self, x, scale_shift = None):
+ B, C, H, W = x.shape
+ x = torch.fft.rfft2(x, dim=(2, 3), norm='ortho')
+ x = x * torch.view_as_complex(self.complex_weight)
+ x = torch.fft.irfft2(x, s=(H, W), dim=(2, 3), norm='ortho')
+ x = x.reshape(B, C, H, W)
+
+ return x
+
+class ResnetBlock(nn.Module):
+ def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(time_emb_dim, dim_out * 2)
+ ) if exists(time_emb_dim) else None
+
+ self.block1 = Block(dim, dim_out, groups = groups)
+ self.block2 = Block(dim_out, dim_out, groups = groups)
+ self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
+
+ def forward(self, x, time_emb = None):
+
+ scale_shift = None
+ if exists(self.mlp) and exists(time_emb):
+ time_emb = self.mlp(time_emb)
+ time_emb = rearrange(time_emb, 'b c -> b c 1 1')
+ scale_shift = time_emb.chunk(2, dim = 1)
+
+ h = self.block1(x, scale_shift = scale_shift)
+
+ h = self.block2(h)
+
+ return h + self.res_conv(x)
+
+class ResnetBlockFFT(nn.Module):
+ def __init__(self, dim, dim_out, h, w, *, time_emb_dim = None, groups = 8):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(time_emb_dim, dim_out * 2)
+ ) if exists(time_emb_dim) else None
+
+ self.block1 = Block(dim, dim_out, groups = groups)
+ self.block2 = BlockFFT(dim_out, h, w, groups = groups)
+ self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
+
+ def forward(self, x, time_emb = None):
+
+ scale_shift = None
+ if exists(self.mlp) and exists(time_emb):
+ time_emb = self.mlp(time_emb)
+ time_emb = rearrange(time_emb, 'b c -> b c 1 1')
+ scale_shift = time_emb.chunk(2, dim = 1)
+
+ h = self.block1(x, scale_shift = scale_shift)
+
+ h = self.block2(h)
+
+ return h + self.res_conv(x)
+
+class ResnetDownsampleBlock(nn.Module):
+ def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(time_emb_dim, dim_out * 2)
+ ) if exists(time_emb_dim) else None
+
+ self.block1 = Block(dim, dim_out, groups = groups)
+ self.block2 = nn.Sequential(
+ WeightStandardizedConv2d(dim_out, dim_out, 3, stride=2, padding=1),
+ nn.GroupNorm(groups, dim_out),
+ nn.SiLU()
+ )
+ self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
+
+ def forward(self, x, time_emb = None):
+
+ scale_shift = None
+ if exists(self.mlp) and exists(time_emb):
+ time_emb = self.mlp(time_emb)
+ time_emb = rearrange(time_emb, 'b c -> b c 1 1')
+ scale_shift = time_emb.chunk(2, dim = 1)
+
+ h = self.block1(x, scale_shift = scale_shift)
+
+ h = self.block2(h)
+
+ return h + self.res_conv(
+ F.interpolate(x, size=h.shape[-2:], mode="bilinear")
+ )
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads = 4, dim_head = 32):
+ super().__init__()
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
+
+ self.to_out = nn.Sequential(
+ nn.Conv2d(hidden_dim, dim, 1),
+ LayerNorm(dim)
+ )
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x).chunk(3, dim = 1)
+ q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
+
+ q = q.softmax(dim = -2)
+ k = k.softmax(dim = -1)
+
+ q = q * self.scale
+ v = v / (h * w)
+
+ context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
+
+ out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
+ out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
+ return self.to_out(out)
+
+class Attention(nn.Module):
+ def __init__(self, dim, heads = 4, dim_head = 32):
+ super().__init__()
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+ hidden_dim = dim_head * heads
+
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x).chunk(3, dim=1)
+ q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv)
+
+ q = q * self.scale
+
+ sim = torch.einsum('b h d i, b h d j -> b h i j', q, k)
+ attn = sim.softmax(dim=-1)
+ out = torch.einsum('b h i j, b h d j -> b h i d', attn, v)
+
+ out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w)
+ return self.to_out(out)
+
+
+class ConditionEncoder(nn.Module):
+ def __init__(self,
+ down_dim_mults=(2, 4, 8),
+ dim=64,
+ in_dim=1,
+ out_dim=64):
+ super(ConditionEncoder, self).__init__()
+ self.init_conv = nn.Sequential(
+ nn.Conv2d(in_dim, dim, kernel_size=3, stride=1, padding=1),
+ nn.GroupNorm(num_groups=min(dim // 4, 8), num_channels=dim),
+ )
+ self.num_resolutions = len(down_dim_mults)
+ self.downs = nn.ModuleList()
+ in_mults = (1,) + tuple(down_dim_mults[:-1])
+ in_dims = [mult*dim for mult in in_mults]
+ out_dims = [mult*dim for mult in down_dim_mults]
+ for i_level in range(self.num_resolutions):
+ block_in = in_dims[i_level]
+ block_out = out_dims[i_level]
+ self.downs.append(ResnetDownsampleBlock(dim=block_in,
+ dim_out=block_out))
+ if self.num_resolutions < 1:
+ self.out_conv = nn.Conv2d(dim, out_dim, 1)
+ else:
+ self.out_conv = nn.Conv2d(out_dims[-1], out_dim, 1)
+
+ def forward(self, x):
+ x = self.init_conv(x)
+ for down_layer in self.downs:
+ x = down_layer(x)
+ x = self.out_conv(x)
+ return x
+
+
+class Unet(nn.Module):
+ def __init__(
+ self,
+ dim,
+ init_dim=None,
+ out_dim=None,
+ dim_mults=(1, 2, 4, 8),
+ cond_in_dim=1,
+ cond_dim=64,
+ cond_dim_mults=(2, 4, 8),
+ channels=1,
+ out_mul=1,
+ self_condition=False,
+ resnet_block_groups=8,
+ learned_variance=False,
+ learned_sinusoidal_cond=False,
+ random_fourier_features=False,
+ learned_sinusoidal_dim=16,
+ window_sizes1=[[16, 16], [8, 8], [4, 4], [2, 2]],
+ window_sizes2=[[16, 16], [8, 8], [4, 4], [2, 2]],
+ fourier_scale=16,
+ ckpt_path=None,
+ ignore_keys=[],
+ cfg={},
+ **kwargs
+ ):
+ super().__init__()
+
+ # determine dimensions
+ self.cond_pe = cfg.get('cond_pe', False)
+ num_pos_feats = cfg.num_pos_feats if self.cond_pe else 0
+ self.channels = channels
+ self.self_condition = self_condition
+ input_channels = channels * (2 if self_condition else 1)
+
+ init_dim = default(init_dim, dim)
+ # self.init_conv_mask = nn.Sequential(
+ # nn.Conv2d(cond_in_dim, cond_dim, 3, padding=1),
+ # nn.GroupNorm(num_groups=min(init_dim // 4, 8), num_channels=init_dim),
+ # nn.SiLU(),
+ # nn.Conv2d(cond_dim, cond_dim, 3, padding=1),
+ # )
+ # self.init_conv_mask = ConditionEncoder(down_dim_mults=cond_dim_mults, dim=cond_dim,
+ # in_dim=cond_in_dim, out_dim=init_dim)
+
+ if cfg.cond_net == 'effnet':
+ f_condnet = 48
+ if cfg.get('without_pretrain', False):
+ self.init_conv_mask = efficientnet_b7()
+ else:
+ self.init_conv_mask = efficientnet_b7(weights=EfficientNet_B7_Weights)
+ elif cfg.cond_net == 'resnet':
+ f_condnet = 256
+ if cfg.get('without_pretrain', False):
+ self.init_conv_mask = resnet101()
+ else:
+ self.init_conv_mask = resnet101(weights=ResNet101_Weights)
+ elif cfg.cond_net == 'swin':
+ f_condnet = 128
+ if cfg.get('without_pretrain', False):
+ self.init_conv_mask = swin_b()
+ else:
+ swin_b_model = swin_b(pretrained=False)
+ swin_b_model.load_state_dict(torch.load(custom_torch_download(filename="swin_b-68c6b09e.pth")), strict=False)
+ self.init_conv_mask = swin_b_model
+ elif cfg.cond_net == 'vgg':
+ f_condnet = 128
+ if cfg.get('without_pretrain', False):
+ self.init_conv_mask = vgg16()
+ else:
+ self.init_conv_mask = vgg16(weights=VGG16_Weights)
+ else:
+ raise NotImplementedError
+ self.init_conv = nn.Sequential(
+ nn.Conv2d(input_channels + f_condnet, init_dim, 7, padding=3),
+ nn.GroupNorm(num_groups=min(init_dim // 4, 8), num_channels=init_dim),
+ )
+
+ if self.cond_pe:
+ self.cond_pos_embedding = nn.Sequential(
+ PositionEmbeddingLearned(
+ feature_size=cfg.cond_feature_size, num_pos_feats=cfg.num_pos_feats//2),
+ nn.Conv2d(num_pos_feats + init_dim, init_dim, 1)
+ )
+ # self.init_conv_mask = nn.Conv2d(1, init_dim, 7, padding=3)
+
+ dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
+ dims_rev = dims[::-1]
+ in_out = list(zip(dims[:-1], dims[1:]))
+ self.projects = nn.ModuleList()
+ print(cfg.cond_net)
+ if cfg.cond_net == 'effnet':
+ self.projects.append(nn.Conv2d(48, dims[0], 1))
+ self.projects.append(nn.Conv2d(80, dims[1], 1))
+ self.projects.append(nn.Conv2d(224, dims[2], 1))
+ self.projects.append(nn.Conv2d(640, dims[3], 1))
+ print(len(self.projects))
+ elif cfg.cond_net == 'vgg':
+ self.projects.append(nn.Conv2d(128, dims[0], 1))
+ self.projects.append(nn.Conv2d(256, dims[1], 1))
+ self.projects.append(nn.Conv2d(512, dims[2], 1))
+ self.projects.append(nn.Conv2d(512, dims[3], 1))
+ else:
+ self.projects.append(nn.Conv2d(f_condnet, dims[0], 1))
+ self.projects.append(nn.Conv2d(f_condnet*2, dims[1], 1))
+ self.projects.append(nn.Conv2d(f_condnet*4, dims[2], 1))
+ self.projects.append(nn.Conv2d(f_condnet*8, dims[3], 1))
+ #print(len(self.projects))
+
+ block_klass = partial(ResnetBlock, groups = resnet_block_groups)
+
+ # time embeddings
+
+ time_dim = dim * 4
+
+ self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features
+
+ if self.random_or_learned_sinusoidal_cond:
+ sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
+ fourier_dim = learned_sinusoidal_dim + 1
+ else:
+ sinu_pos_emb = GaussianFourierProjection(dim//2, scale=fourier_scale)
+ fourier_dim = dim
+
+ self.time_mlp = nn.Sequential(
+ sinu_pos_emb,
+ nn.Linear(fourier_dim, time_dim),
+ nn.GELU(),
+ nn.Linear(time_dim, time_dim)
+ )
+
+ # layers
+
+ self.downs = nn.ModuleList([])
+ self.downs_mask = nn.ModuleList([])
+ self.ups = nn.ModuleList([])
+ self.relation_layers_down = nn.ModuleList([])
+ self.relation_layers_up = nn.ModuleList([])
+ self.ups2 = nn.ModuleList([])
+ self.relation_layers_up2 = nn.ModuleList([])
+ num_resolutions = len(in_out)
+ input_size = cfg.get('input_size', [80, 80])
+ feature_size_list = [[int(input_size[0]/2**k), int(input_size[1]/2**k)] for k in range(len(dim_mults))]
+
+
+ for ind, (dim_in, dim_out) in enumerate(in_out):
+ is_last = ind >= (num_resolutions - 1)
+
+ self.downs.append(nn.ModuleList([
+ block_klass(dim_in, dim_in, time_emb_dim = time_dim),
+ block_klass(dim_in, dim_in, time_emb_dim = time_dim),
+ Residual(PreNorm(dim_in, LinearAttention(dim_in))),
+ Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
+ ]))
+ # self.downs_mask.append(nn.ModuleList([
+ # block_klass(dim_in, dim_in, time_emb_dim=time_dim),
+ # # block_klass(dim_in, dim_in, time_emb_dim=time_dim),
+ # Residual(PreNorm(dim_in, LinearAttention(dim_in))),
+ # Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding=1)
+ # ]))
+ self.relation_layers_down.append(RelationNet(in_channel1=dims[ind], in_channel2=dims[ind], nhead=8,
+ layers=1, embed_dim=dims[ind], ffn_dim=dims[ind]*2,
+ window_size1=window_sizes1[ind], window_size2=window_sizes2[ind])
+ )
+
+ mid_dim = dims[-1]
+ self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
+ self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
+ self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
+ self.decouple1 = nn.Sequential(
+ nn.GroupNorm(num_groups=min(mid_dim // 4, 8), num_channels=mid_dim),
+ nn.Conv2d(mid_dim, mid_dim, 3, padding=1),
+ BlockFFT(mid_dim, input_size[0]//8, input_size[1]//8),
+ )
+ self.decouple2 = nn.Sequential(
+ nn.GroupNorm(num_groups=min(mid_dim // 4, 8), num_channels=mid_dim),
+ nn.Conv2d(mid_dim, mid_dim, 3, padding=1),
+ BlockFFT(mid_dim, input_size[0]//8, input_size[1]//8),
+ )
+
+ for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
+ is_last = ind == (len(in_out) - 1)
+
+ self.ups.append(nn.ModuleList([
+ block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
+ block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
+ Residual(PreNorm(dim_out, LinearAttention(dim_out))),
+ Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
+ ]))
+ self.relation_layers_up.append(RelationNet(in_channel1=dims_rev[ind+1], in_channel2=dims_rev[ind],
+ nhead=8, layers=1, embed_dim=dims_rev[ind],
+ ffn_dim=dims_rev[ind] * 2,
+ window_size1=window_sizes1[::-1][ind],
+ window_size2=window_sizes2[::-1][ind])
+ )
+ self.ups2.append(nn.ModuleList([
+ block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
+ block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
+ Residual(PreNorm(dim_out, LinearAttention(dim_out))),
+ Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding=1)
+ ]))
+ self.relation_layers_up2.append(RelationNet(in_channel1=dims_rev[ind + 1], in_channel2=dims_rev[ind],
+ nhead=8, layers=1, embed_dim=dims_rev[ind],
+ ffn_dim=dims_rev[ind] * 2,
+ window_size1=window_sizes1[::-1][ind],
+ window_size2=window_sizes2[::-1][ind])
+ )
+
+ default_out_dim = channels * (1 if not learned_variance else 2)
+ self.out_dim = default(out_dim, default_out_dim)
+
+ self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)
+ self.final_conv = nn.Conv2d(dim, self.out_dim * out_mul, 1)
+
+ self.final_res_block2 = block_klass(dim * 2, dim, time_emb_dim = time_dim)
+ self.final_conv2 = nn.Conv2d(dim, self.out_dim, 1)
+
+ # self.init_weights()
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+ fix_bb = cfg.get('fix_bb', True)
+ if fix_bb:
+ for n, p in self.init_conv_mask.named_parameters():
+ p.requires_grad = False
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["model"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ msg = self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path}")
+ print('==>Load Unet Info: ', msg)
+
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0.)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1.)
+ nn.init.constant_(m.bias, 0.)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.weight, 1.)
+ nn.init.constant_(m.bias, 0.)
+ elif isinstance(m, nn.Linear):
+ nn.init.xavier_normal_(m.weight)
+ nn.init.constant_(m.bias, 0.)
+
+ def forward(self, x, time, mask, x_self_cond = None, **kwargs):
+ if self.self_condition:
+ x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
+ x = torch.cat((x_self_cond, x), dim = 1)
+ sigma = time.reshape(-1, 1, 1, 1)
+ eps = 1e-4
+ c_skip1 = 1 - sigma
+ c_skip2 = torch.sqrt(sigma)
+ # c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
+ c_out1 = sigma / torch.sqrt(sigma ** 2 + 1)
+ c_out2 = torch.sqrt(1 - sigma) / torch.sqrt(sigma ** 2 + 1)
+ c_in = 1
+
+ x_clone = x.clone()
+ x = c_in * x
+ # mask = torch.cat([], dim=1)
+ hm = self.init_conv_mask(mask)
+ # if self.cond_pe:
+ # m = self.cond_pos_embedding(m)
+ x = self.init_conv(torch.cat([x, F.interpolate(hm[0], size=x.shape[-2:], mode="bilinear")], dim=1))
+ r = x.clone()
+
+ t = self.time_mlp(torch.log(time)/4)
+
+ h = []
+ h2 = []
+ for i, layer in enumerate(self.projects):
+ # print(hm[i].shape)
+ hm[i] = layer(hm[i])
+ hm2 = []
+ for i in range(len(hm)):
+ hm2.append(hm[i].clone())
+ # hm = []
+ # hm2 = []
+ for i, ((block1, block2, attn, downsample), relation_layer) \
+ in enumerate(zip(self.downs, self.relation_layers_down)):
+ x = block1(x, t)
+ h.append(x)
+ h2.append(x.clone())
+ # m = m_block(m, t)
+ # hm.append(m)
+ # hm2.append(m.clone())
+
+ x = relation_layer(hm[i], x)
+
+ x = block2(x, t)
+ x = attn(x)
+ h.append(x)
+ h2.append(x.clone())
+
+ x = downsample(x)
+ # m = m_downsample(m)
+
+
+ # x = x + F.interpolate(hm[-1], size=x.shape[2:], mode="bilinear", align_corners=True)
+ x = self.mid_block1(x, t)
+ x = self.mid_attn(x)
+ x = self.mid_block2(x, t)
+ x1 = x + self.decouple1(x)
+ x2 = x + self.decouple2(x)
+
+ x = x1
+ for (block1, block2, attn, upsample), relation_layer in zip(self.ups, self.relation_layers_up):
+ x = torch.cat((x, h.pop()), dim = 1)
+ x = block1(x, t)
+ x = relation_layer(hm.pop(), x)
+ x = torch.cat((x, h.pop()), dim = 1)
+ x = block2(x, t)
+ x = attn(x)
+ x = upsample(x)
+
+ x1 = torch.cat((x, r), dim=1)
+ x1 = self.final_res_block(x1, t)
+ x1 = self.final_conv(x1)
+
+ x = x2
+ for (block1, block2, attn, upsample), relation_layer in zip(self.ups2, self.relation_layers_up2):
+ x = torch.cat((x, h2.pop()), dim = 1)
+ x = block1(x, t)
+ x = relation_layer(hm2.pop(), x)
+ x = torch.cat((x, h2.pop()), dim = 1)
+ x = block2(x, t)
+ x = attn(x)
+ x = upsample(x)
+
+ x2 = torch.cat((x, r), dim=1)
+ x2 = self.final_res_block2(x2, t)
+ x2 = self.final_conv2(x2)
+ # sigma = time.reshape(x1.shape[0], *((1,) * (len(x1.shape) - 1)))
+ # scale_C = torch.exp(sigma)
+ x1 = c_skip1 * x_clone + c_out1 * x1
+ x2 = c_skip2 * x_clone + c_out2 * x2
+ return x1, x2
+
+
+if __name__ == "__main__":
+ # resnet = resnet101(weights=ResNet101_Weights)
+ # effnet = efficientnet_b7(weights=EfficientNet_B7_Weights)
+ # effnet = efficientnet_b7(weights=None)
+ # x = torch.rand(1, 3, 320, 320)
+ # y = effnet(x)
+ model = Unet(dim=128, dim_mults=(1, 2, 4, 4),
+ cond_dim=128,
+ cond_dim_mults=(2, 4, ),
+ channels=1,
+ window_sizes1=[[8, 8], [4, 4], [2, 2], [1, 1]],
+ window_sizes2=[[8, 8], [4, 4], [2, 2], [1, 1]],
+ cfg=fvcore.common.config.CfgNode({'cond_pe': False, 'input_size': [80, 80],
+ 'cond_feature_size': (32, 128), 'cond_net': 'vgg',
+ 'num_pos_feats': 96})
+ )
+ x = torch.rand(1, 1, 80, 80)
+ mask = torch.rand(1, 3, 320, 320)
+ time = torch.tensor([0.5124])
+ with torch.no_grad():
+ y = model(x, time, mask)
+ pass
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/quantization.py b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/quantization.py
new file mode 100644
index 0000000000000000000000000000000000000000..1db3d379f0e44a885f7e4a249b1014172b95264d
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/quantization.py
@@ -0,0 +1,103 @@
+import torch
+from torch import nn as nn
+from torch.nn import Parameter
+
+
+def weight_quantization(b):
+ def uniform_quant(x, b):
+ xdiv = x.mul((2 ** b - 1))
+ xhard = xdiv.round().div(2 ** b - 1)
+ return xhard
+
+ class _pq(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, alpha):
+ input.div_(alpha) # weights are first divided by alpha
+ input_c = input.clamp(min=-1, max=1) # then clipped to [-1,1]
+ sign = input_c.sign()
+ input_abs = input_c.abs()
+ input_q = uniform_quant(input_abs, b).mul(sign)
+ ctx.save_for_backward(input, input_q)
+ input_q = input_q.mul(alpha) # rescale to the original range
+ return input_q
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_input = grad_output.clone() # grad for weights will not be clipped
+ input, input_q = ctx.saved_tensors
+ i = (input.abs() > 1.).float()
+ sign = input.sign()
+ grad_alpha = (grad_output * (sign * i + (input_q - input) * (1 - i))).sum()
+ return grad_input, grad_alpha
+
+ return _pq().apply
+
+
+class weight_quantize_fn(nn.Module):
+ def __init__(self, bit_w):
+ super(weight_quantize_fn, self).__init__()
+ assert bit_w > 0
+
+ self.bit_w = bit_w - 1
+ self.weight_q = weight_quantization(b=self.bit_w)
+ self.register_parameter('w_alpha', Parameter(torch.tensor(3.0), requires_grad=True))
+
+ def forward(self, weight):
+ mean = weight.data.mean()
+ std = weight.data.std()
+ weight = weight.add(-mean).div(std) # weights normalization
+ weight_q = self.weight_q(weight, self.w_alpha)
+ return weight_q
+
+ def change_bit(self, bit_w):
+ self.bit_w = bit_w - 1
+ self.weight_q = weight_quantization(b=self.bit_w)
+
+def act_quantization(b, signed=False):
+ def uniform_quant(x, b=3):
+ xdiv = x.mul(2 ** b - 1)
+ xhard = xdiv.round().div(2 ** b - 1)
+ return xhard
+
+ class _uq(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, alpha):
+ input = input.div(alpha)
+ input_c = input.clamp(min=-1, max=1) if signed else input.clamp(max=1)
+ input_q = uniform_quant(input_c, b)
+ ctx.save_for_backward(input, input_q)
+ input_q = input_q.mul(alpha)
+ return input_q
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_input = grad_output.clone()
+ input, input_q = ctx.saved_tensors
+ i = (input.abs() > 1.).float()
+ sign = input.sign()
+ grad_alpha = (grad_output * (sign * i + (input_q - input) * (1 - i))).sum()
+ grad_input = grad_input * (1 - i)
+ return grad_input, grad_alpha
+
+ return _uq().apply
+
+class act_quantize_fn(nn.Module):
+ def __init__(self, bit_a, signed=False):
+ super(act_quantize_fn, self).__init__()
+ self.bit_a = bit_a
+ self.signed = signed
+ if signed:
+ self.bit_a -= 1
+ assert bit_a > 0
+
+ self.act_q = act_quantization(b=self.bit_a, signed=signed)
+ self.register_parameter('a_alpha', Parameter(torch.tensor(8.0), requires_grad=True))
+
+ def forward(self, x):
+ return self.act_q(x, self.a_alpha)
+
+ def change_bit(self, bit_a):
+ self.bit_a = bit_a
+ if self.signed:
+ self.bit_a -= 1
+ self.act_q = act_quantization(b=self.bit_a, signed=self.signed)
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/resnet.py b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..256a73a06e32b1cc07d35a8dfcfb3e4601ceeb38
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/resnet.py
@@ -0,0 +1,963 @@
+from functools import partial
+from typing import Type, Any, Callable, Union, List, Optional
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from torchvision.transforms._presets import ImageClassification
+from torchvision.utils import _log_api_usage_once
+from torchvision.models._api import WeightsEnum, Weights
+from torchvision.models._meta import _IMAGENET_CATEGORIES
+from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param
+
+
+__all__ = [
+ "ResNet",
+ "ResNet18_Weights",
+ "ResNet34_Weights",
+ "ResNet50_Weights",
+ "ResNet101_Weights",
+ "ResNet152_Weights",
+ "ResNeXt50_32X4D_Weights",
+ "ResNeXt101_32X8D_Weights",
+ "ResNeXt101_64X4D_Weights",
+ "Wide_ResNet50_2_Weights",
+ "Wide_ResNet101_2_Weights",
+ "resnet18",
+ "resnet34",
+ "resnet50",
+ "resnet101",
+ "resnet152",
+ "resnext50_32x4d",
+ "resnext101_32x8d",
+ "resnext101_64x4d",
+ "wide_resnet50_2",
+ "wide_resnet101_2",
+]
+
+
+def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
+ """3x3 convolution with padding"""
+ return nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ groups=groups,
+ bias=False,
+ dilation=dilation,
+ )
+
+
+def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion: int = 1
+
+ def __init__(
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ downsample: Optional[nn.Module] = None,
+ groups: int = 1,
+ base_width: int = 64,
+ dilation: int = 1,
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
+ ) -> None:
+ super().__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ if groups != 1 or base_width != 64:
+ raise ValueError("BasicBlock only supports groups=1 and base_width=64")
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = norm_layer(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = norm_layer(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x: Tensor) -> Tensor:
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+
+ expansion: int = 4
+
+ def __init__(
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ downsample: Optional[nn.Module] = None,
+ groups: int = 1,
+ base_width: int = 64,
+ dilation: int = 1,
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
+ ) -> None:
+ super().__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ width = int(planes * (base_width / 64.0)) * groups
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv1x1(inplanes, width)
+ self.bn1 = norm_layer(width)
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
+ self.bn2 = norm_layer(width)
+ self.conv3 = conv1x1(width, planes * self.expansion)
+ self.bn3 = norm_layer(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x: Tensor) -> Tensor:
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+ def __init__(
+ self,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ layers: List[int],
+ num_classes: int = 1000,
+ zero_init_residual: bool = False,
+ groups: int = 1,
+ width_per_group: int = 64,
+ replace_stride_with_dilation: Optional[List[bool]] = None,
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
+ ) -> None:
+ super().__init__()
+ _log_api_usage_once(self)
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self._norm_layer = norm_layer
+
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ # each element in the tuple indicates if we should replace
+ # the 2x2 stride with a dilated convolution instead
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError(
+ "replace_stride_with_dilation should be None "
+ f"or a 3-element tuple, got {replace_stride_with_dilation}"
+ )
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = norm_layer(self.inplanes)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
+ # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ # self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ # Zero-initialize the last BN in each residual branch,
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck) and m.bn3.weight is not None:
+ nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
+ elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
+ nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
+
+ def _make_layer(
+ self,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ planes: int,
+ blocks: int,
+ stride: int = 1,
+ dilate: bool = False,
+ ) -> nn.Sequential:
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(
+ block(
+ self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
+ )
+ )
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(
+ block(
+ self.inplanes,
+ planes,
+ groups=self.groups,
+ base_width=self.base_width,
+ dilation=self.dilation,
+ norm_layer=norm_layer,
+ )
+ )
+
+ return nn.Sequential(*layers)
+
+ def _forward_impl(self, x: Tensor) -> Tensor:
+ # See note [TorchScript super()]
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ feats = []
+ x = self.layer1(x)
+ feats.append(x)
+ x = self.layer2(x)
+ feats.append(x)
+ x = self.layer3(x)
+ feats.append(x)
+ x = self.layer4(x)
+ feats.append(x)
+
+ # x = self.avgpool(x)
+ # x = torch.flatten(x, 1)
+ # x = self.fc(x)
+
+ return feats
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self._forward_impl(x)
+
+
+def _resnet(
+ block: Type[Union[BasicBlock, Bottleneck]],
+ layers: List[int],
+ weights: Optional[WeightsEnum],
+ progress: bool,
+ **kwargs: Any,
+) -> ResNet:
+ if weights is not None:
+ _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+
+ model = ResNet(block, layers, **kwargs)
+
+ if weights is not None:
+ model.load_state_dict(weights.get_state_dict(progress=progress), strict=False)
+
+ return model
+
+
+_COMMON_META = {
+ "min_size": (1, 1),
+ "categories": _IMAGENET_CATEGORIES,
+}
+
+
+class ResNet18_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/resnet18-f37072fd.pth",
+ transforms=partial(ImageClassification, crop_size=224),
+ meta={
+ **_COMMON_META,
+ "num_params": 11689512,
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 69.758,
+ "acc@5": 89.078,
+ }
+ },
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
+ },
+ )
+ DEFAULT = IMAGENET1K_V1
+
+
+class ResNet34_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/resnet34-b627a593.pth",
+ transforms=partial(ImageClassification, crop_size=224),
+ meta={
+ **_COMMON_META,
+ "num_params": 21797672,
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 73.314,
+ "acc@5": 91.420,
+ }
+ },
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
+ },
+ )
+ DEFAULT = IMAGENET1K_V1
+
+
+class ResNet50_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/resnet50-0676ba61.pth",
+ transforms=partial(ImageClassification, crop_size=224),
+ meta={
+ **_COMMON_META,
+ "num_params": 25557032,
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 76.130,
+ "acc@5": 92.862,
+ }
+ },
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
+ },
+ )
+ IMAGENET1K_V2 = Weights(
+ url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth",
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+ meta={
+ **_COMMON_META,
+ "num_params": 25557032,
+ "recipe": "https://github.com/pytorch/vision/issues/3995#issuecomment-1013906621",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 80.858,
+ "acc@5": 95.434,
+ }
+ },
+ "_docs": """
+ These weights improve upon the results of the original paper by using TorchVision's `new training recipe
+ `_.
+ """,
+ },
+ )
+ DEFAULT = IMAGENET1K_V2
+
+
+class ResNet101_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/resnet101-63fe2227.pth",
+ transforms=partial(ImageClassification, crop_size=224),
+ meta={
+ **_COMMON_META,
+ "num_params": 44549160,
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 77.374,
+ "acc@5": 93.546,
+ }
+ },
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
+ },
+ )
+ IMAGENET1K_V2 = Weights(
+ url="https://download.pytorch.org/models/resnet101-cd907fc2.pth",
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+ meta={
+ **_COMMON_META,
+ "num_params": 44549160,
+ "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 81.886,
+ "acc@5": 95.780,
+ }
+ },
+ "_docs": """
+ These weights improve upon the results of the original paper by using TorchVision's `new training recipe
+ `_.
+ """,
+ },
+ )
+ DEFAULT = IMAGENET1K_V2
+
+
+class ResNet152_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/resnet152-394f9c45.pth",
+ transforms=partial(ImageClassification, crop_size=224),
+ meta={
+ **_COMMON_META,
+ "num_params": 60192808,
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 78.312,
+ "acc@5": 94.046,
+ }
+ },
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
+ },
+ )
+ IMAGENET1K_V2 = Weights(
+ url="https://download.pytorch.org/models/resnet152-f82ba261.pth",
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+ meta={
+ **_COMMON_META,
+ "num_params": 60192808,
+ "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 82.284,
+ "acc@5": 96.002,
+ }
+ },
+ "_docs": """
+ These weights improve upon the results of the original paper by using TorchVision's `new training recipe
+ `_.
+ """,
+ },
+ )
+ DEFAULT = IMAGENET1K_V2
+
+
+class ResNeXt50_32X4D_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
+ transforms=partial(ImageClassification, crop_size=224),
+ meta={
+ **_COMMON_META,
+ "num_params": 25028904,
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 77.618,
+ "acc@5": 93.698,
+ }
+ },
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
+ },
+ )
+ IMAGENET1K_V2 = Weights(
+ url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth",
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+ meta={
+ **_COMMON_META,
+ "num_params": 25028904,
+ "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 81.198,
+ "acc@5": 95.340,
+ }
+ },
+ "_docs": """
+ These weights improve upon the results of the original paper by using TorchVision's `new training recipe
+ `_.
+ """,
+ },
+ )
+ DEFAULT = IMAGENET1K_V2
+
+
+class ResNeXt101_32X8D_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
+ transforms=partial(ImageClassification, crop_size=224),
+ meta={
+ **_COMMON_META,
+ "num_params": 88791336,
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 79.312,
+ "acc@5": 94.526,
+ }
+ },
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
+ },
+ )
+ IMAGENET1K_V2 = Weights(
+ url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth",
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+ meta={
+ **_COMMON_META,
+ "num_params": 88791336,
+ "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 82.834,
+ "acc@5": 96.228,
+ }
+ },
+ "_docs": """
+ These weights improve upon the results of the original paper by using TorchVision's `new training recipe
+ `_.
+ """,
+ },
+ )
+ DEFAULT = IMAGENET1K_V2
+
+
+class ResNeXt101_64X4D_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/resnext101_64x4d-173b62eb.pth",
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+ meta={
+ **_COMMON_META,
+ "num_params": 83455272,
+ "recipe": "https://github.com/pytorch/vision/pull/5935",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 83.246,
+ "acc@5": 96.454,
+ }
+ },
+ "_docs": """
+ These weights were trained from scratch by using TorchVision's `new training recipe
+ `_.
+ """,
+ },
+ )
+ DEFAULT = IMAGENET1K_V1
+
+
+class Wide_ResNet50_2_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
+ transforms=partial(ImageClassification, crop_size=224),
+ meta={
+ **_COMMON_META,
+ "num_params": 68883240,
+ "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 78.468,
+ "acc@5": 94.086,
+ }
+ },
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
+ },
+ )
+ IMAGENET1K_V2 = Weights(
+ url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth",
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+ meta={
+ **_COMMON_META,
+ "num_params": 68883240,
+ "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 81.602,
+ "acc@5": 95.758,
+ }
+ },
+ "_docs": """
+ These weights improve upon the results of the original paper by using TorchVision's `new training recipe
+ `_.
+ """,
+ },
+ )
+ DEFAULT = IMAGENET1K_V2
+
+
+class Wide_ResNet101_2_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
+ transforms=partial(ImageClassification, crop_size=224),
+ meta={
+ **_COMMON_META,
+ "num_params": 126886696,
+ "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 78.848,
+ "acc@5": 94.284,
+ }
+ },
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
+ },
+ )
+ IMAGENET1K_V2 = Weights(
+ url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth",
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+ meta={
+ **_COMMON_META,
+ "num_params": 126886696,
+ "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 82.510,
+ "acc@5": 96.020,
+ }
+ },
+ "_docs": """
+ These weights improve upon the results of the original paper by using TorchVision's `new training recipe
+ `_.
+ """,
+ },
+ )
+ DEFAULT = IMAGENET1K_V2
+
+
+@handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1))
+def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
+ """ResNet-18 from `Deep Residual Learning for Image Recognition `__.
+
+ Args:
+ weights (:class:`~torchvision.models.ResNet18_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.ResNet18_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+
+ .. autoclass:: torchvision.models.ResNet18_Weights
+ :members:
+ """
+ weights = ResNet18_Weights.verify(weights)
+
+ return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)
+
+
+@handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1))
+def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
+ """ResNet-34 from `Deep Residual Learning for Image Recognition `__.
+
+ Args:
+ weights (:class:`~torchvision.models.ResNet34_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.ResNet34_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+
+ .. autoclass:: torchvision.models.ResNet34_Weights
+ :members:
+ """
+ weights = ResNet34_Weights.verify(weights)
+
+ return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs)
+
+
+@handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1))
+def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
+ """ResNet-50 from `Deep Residual Learning for Image Recognition `__.
+
+ .. note::
+ The bottleneck of TorchVision places the stride for downsampling to the second 3x3
+ convolution while the original paper places it to the first 1x1 convolution.
+ This variant improves the accuracy and is known as `ResNet V1.5
+ `_.
+
+ Args:
+ weights (:class:`~torchvision.models.ResNet50_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.ResNet50_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+
+ .. autoclass:: torchvision.models.ResNet50_Weights
+ :members:
+ """
+ weights = ResNet50_Weights.verify(weights)
+
+ return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
+
+
+@handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1))
+def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
+ """ResNet-101 from `Deep Residual Learning for Image Recognition `__.
+
+ .. note::
+ The bottleneck of TorchVision places the stride for downsampling to the second 3x3
+ convolution while the original paper places it to the first 1x1 convolution.
+ This variant improves the accuracy and is known as `ResNet V1.5
+ `_.
+
+ Args:
+ weights (:class:`~torchvision.models.ResNet101_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.ResNet101_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+
+ .. autoclass:: torchvision.models.ResNet101_Weights
+ :members:
+ """
+ weights = ResNet101_Weights.verify(weights)
+
+ return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
+
+
+@handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1))
+def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
+ """ResNet-152 from `Deep Residual Learning for Image Recognition `__.
+
+ .. note::
+ The bottleneck of TorchVision places the stride for downsampling to the second 3x3
+ convolution while the original paper places it to the first 1x1 convolution.
+ This variant improves the accuracy and is known as `ResNet V1.5
+ `_.
+
+ Args:
+ weights (:class:`~torchvision.models.ResNet152_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.ResNet152_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+
+ .. autoclass:: torchvision.models.ResNet152_Weights
+ :members:
+ """
+ weights = ResNet152_Weights.verify(weights)
+
+ return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs)
+
+
+@handle_legacy_interface(weights=("pretrained", ResNeXt50_32X4D_Weights.IMAGENET1K_V1))
+def resnext50_32x4d(
+ *, weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any
+) -> ResNet:
+ """ResNeXt-50 32x4d model from
+ `Aggregated Residual Transformation for Deep Neural Networks `_.
+
+ Args:
+ weights (:class:`~torchvision.models.ResNeXt50_32X4D_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.ResNext50_32X4D_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+ .. autoclass:: torchvision.models.ResNeXt50_32X4D_Weights
+ :members:
+ """
+ weights = ResNeXt50_32X4D_Weights.verify(weights)
+
+ _ovewrite_named_param(kwargs, "groups", 32)
+ _ovewrite_named_param(kwargs, "width_per_group", 4)
+ return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
+
+
+@handle_legacy_interface(weights=("pretrained", ResNeXt101_32X8D_Weights.IMAGENET1K_V1))
+def resnext101_32x8d(
+ *, weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any
+) -> ResNet:
+ """ResNeXt-101 32x8d model from
+ `Aggregated Residual Transformation for Deep Neural Networks `_.
+
+ Args:
+ weights (:class:`~torchvision.models.ResNeXt101_32X8D_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.ResNeXt101_32X8D_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+ .. autoclass:: torchvision.models.ResNeXt101_32X8D_Weights
+ :members:
+ """
+ weights = ResNeXt101_32X8D_Weights.verify(weights)
+
+ _ovewrite_named_param(kwargs, "groups", 32)
+ _ovewrite_named_param(kwargs, "width_per_group", 8)
+ return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
+
+
+def resnext101_64x4d(
+ *, weights: Optional[ResNeXt101_64X4D_Weights] = None, progress: bool = True, **kwargs: Any
+) -> ResNet:
+ """ResNeXt-101 64x4d model from
+ `Aggregated Residual Transformation for Deep Neural Networks `_.
+
+ Args:
+ weights (:class:`~torchvision.models.ResNeXt101_64X4D_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.ResNeXt101_64X4D_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+ .. autoclass:: torchvision.models.ResNeXt101_64X4D_Weights
+ :members:
+ """
+ weights = ResNeXt101_64X4D_Weights.verify(weights)
+
+ _ovewrite_named_param(kwargs, "groups", 64)
+ _ovewrite_named_param(kwargs, "width_per_group", 4)
+ return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
+
+
+@handle_legacy_interface(weights=("pretrained", Wide_ResNet50_2_Weights.IMAGENET1K_V1))
+def wide_resnet50_2(
+ *, weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any
+) -> ResNet:
+ """Wide ResNet-50-2 model from
+ `Wide Residual Networks `_.
+
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+ Args:
+ weights (:class:`~torchvision.models.Wide_ResNet50_2_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.Wide_ResNet50_2_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+ .. autoclass:: torchvision.models.Wide_ResNet50_2_Weights
+ :members:
+ """
+ weights = Wide_ResNet50_2_Weights.verify(weights)
+
+ _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
+ return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
+
+
+@handle_legacy_interface(weights=("pretrained", Wide_ResNet101_2_Weights.IMAGENET1K_V1))
+def wide_resnet101_2(
+ *, weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any
+) -> ResNet:
+ """Wide ResNet-101-2 model from
+ `Wide Residual Networks `_.
+
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-101 has 2048-512-2048
+ channels, and in Wide ResNet-101-2 has 2048-1024-2048.
+
+ Args:
+ weights (:class:`~torchvision.models.Wide_ResNet101_2_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.Wide_ResNet101_2_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+ .. autoclass:: torchvision.models.Wide_ResNet101_2_Weights
+ :members:
+ """
+ weights = Wide_ResNet101_2_Weights.verify(weights)
+
+ _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
+ return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
+
+
+# The dictionary below is internal implementation detail and will be removed in v0.15
+from torchvision.models._utils import _ModelURLs
+
+
+model_urls = _ModelURLs(
+ {
+ "resnet18": ResNet18_Weights.IMAGENET1K_V1.url,
+ "resnet34": ResNet34_Weights.IMAGENET1K_V1.url,
+ "resnet50": ResNet50_Weights.IMAGENET1K_V1.url,
+ "resnet101": ResNet101_Weights.IMAGENET1K_V1.url,
+ "resnet152": ResNet152_Weights.IMAGENET1K_V1.url,
+ "resnext50_32x4d": ResNeXt50_32X4D_Weights.IMAGENET1K_V1.url,
+ "resnext101_32x8d": ResNeXt101_32X8D_Weights.IMAGENET1K_V1.url,
+ "wide_resnet50_2": Wide_ResNet50_2_Weights.IMAGENET1K_V1.url,
+ "wide_resnet101_2": Wide_ResNet101_2_Weights.IMAGENET1K_V1.url,
+ }
+)
diff --git a/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/swin_transformer.py b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/swin_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9915bfd6df1de65f8a84975afcf22a386d15958e
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/swin_transformer.py
@@ -0,0 +1,651 @@
+from functools import partial
+from typing import Optional, Callable, List, Any
+
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+from torchvision.ops.misc import MLP, Permute
+from torchvision.ops.stochastic_depth import StochasticDepth
+from torchvision.transforms._presets import ImageClassification, InterpolationMode
+from torchvision.utils import _log_api_usage_once
+from torchvision.models._api import WeightsEnum, Weights
+from torchvision.models._meta import _IMAGENET_CATEGORIES
+from torchvision.models._utils import _ovewrite_named_param
+
+
+__all__ = [
+ "SwinTransformer",
+ "Swin_T_Weights",
+ "Swin_S_Weights",
+ "Swin_B_Weights",
+ "swin_t",
+ "swin_s",
+ "swin_b",
+]
+
+
+def _patch_merging_pad(x):
+ H, W, _ = x.shape[-3:]
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+ return x
+
+
+torch.fx.wrap("_patch_merging_pad")
+
+
+class PatchMerging(nn.Module):
+ """Patch Merging Layer.
+ Args:
+ dim (int): Number of input channels.
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ """
+
+ def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm):
+ super().__init__()
+ _log_api_usage_once(self)
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def forward(self, x: Tensor):
+ """
+ Args:
+ x (Tensor): input tensor with expected layout of [..., H, W, C]
+ Returns:
+ Tensor with layout of [..., H/2, W/2, 2*C]
+ """
+ x = _patch_merging_pad(x)
+
+ x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C
+ x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C
+ x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C
+ x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C
+
+ x = self.norm(x)
+ x = self.reduction(x) # ... H/2 W/2 2*C
+ return x
+
+
+def shifted_window_attention(
+ input: Tensor,
+ qkv_weight: Tensor,
+ proj_weight: Tensor,
+ relative_position_bias: Tensor,
+ window_size: List[int],
+ num_heads: int,
+ shift_size: List[int],
+ attention_dropout: float = 0.0,
+ dropout: float = 0.0,
+ qkv_bias: Optional[Tensor] = None,
+ proj_bias: Optional[Tensor] = None,
+):
+ """
+ Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both of shifted and non-shifted window.
+ Args:
+ input (Tensor[N, H, W, C]): The input tensor or 4-dimensions.
+ qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value.
+ proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection.
+ relative_position_bias (Tensor): The learned relative position bias added to attention.
+ window_size (List[int]): Window size.
+ num_heads (int): Number of attention heads.
+ shift_size (List[int]): Shift size for shifted window attention.
+ attention_dropout (float): Dropout ratio of attention weight. Default: 0.0.
+ dropout (float): Dropout ratio of output. Default: 0.0.
+ qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
+ proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
+ Returns:
+ Tensor[N, H, W, C]: The output tensor after shifted window attention.
+ """
+ B, H, W, C = input.shape
+ # pad feature maps to multiples of window size
+ pad_r = (window_size[1] - W % window_size[1]) % window_size[1]
+ pad_b = (window_size[0] - H % window_size[0]) % window_size[0]
+ x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b))
+ _, pad_H, pad_W, _ = x.shape
+
+ # If window size is larger than feature size, there is no need to shift window
+ if window_size[0] >= pad_H:
+ shift_size[0] = 0
+ if window_size[1] >= pad_W:
+ shift_size[1] = 0
+
+ # cyclic shift
+ if sum(shift_size) > 0:
+ x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
+
+ # partition windows
+ num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1])
+ x = x.view(B, pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1], C)
+ x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C) # B*nW, Ws*Ws, C
+
+ # multi-head attention
+ qkv = F.linear(x, qkv_weight, qkv_bias)
+ qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2]
+ q = q * (C // num_heads) ** -0.5
+ attn = q.matmul(k.transpose(-2, -1))
+ # add relative position bias
+ attn = attn + relative_position_bias
+
+ if sum(shift_size) > 0:
+ # generate attention mask
+ attn_mask = x.new_zeros((pad_H, pad_W))
+ h_slices = ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None))
+ w_slices = ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None))
+ count = 0
+ for h in h_slices:
+ for w in w_slices:
+ attn_mask[h[0] : h[1], w[0] : w[1]] = count
+ count += 1
+ attn_mask = attn_mask.view(pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1])
+ attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size[0] * window_size[1])
+ attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+ attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1))
+ attn = attn + attn_mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, num_heads, x.size(1), x.size(1))
+
+ attn = F.softmax(attn, dim=-1)
+ attn = F.dropout(attn, p=attention_dropout)
+
+ x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C)
+ x = F.linear(x, proj_weight, proj_bias)
+ x = F.dropout(x, p=dropout)
+
+ # reverse windows
+ x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)
+ x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C)
+
+ # reverse cyclic shift
+ if sum(shift_size) > 0:
+ x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))
+
+ # unpad features
+ x = x[:, :H, :W, :].contiguous()
+ return x
+
+
+torch.fx.wrap("shifted_window_attention")
+
+
+class ShiftedWindowAttention(nn.Module):
+ """
+ See :func:`shifted_window_attention`.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ window_size: List[int],
+ shift_size: List[int],
+ num_heads: int,
+ qkv_bias: bool = True,
+ proj_bias: bool = True,
+ attention_dropout: float = 0.0,
+ dropout: float = 0.0,
+ ):
+ super().__init__()
+ if len(window_size) != 2 or len(shift_size) != 2:
+ raise ValueError("window_size and shift_size must be of length 2")
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.num_heads = num_heads
+ self.attention_dropout = attention_dropout
+ self.dropout = dropout
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
+ ) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ # coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) # 2, Wh, Ww
+ coords = torch.stack(torch.meshgrid(coords_h, coords_w)) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
+
+ def forward(self, x: Tensor):
+ """
+ Args:
+ x (Tensor): Tensor with layout of [B, H, W, C]
+ Returns:
+ Tensor with same layout as input, i.e. [B, H, W, C]
+ """
+
+ N = self.window_size[0] * self.window_size[1]
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # type: ignore[index]
+ relative_position_bias = relative_position_bias.view(N, N, -1)
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
+
+ return shifted_window_attention(
+ x,
+ self.qkv.weight,
+ self.proj.weight,
+ relative_position_bias,
+ self.window_size,
+ self.num_heads,
+ shift_size=self.shift_size,
+ attention_dropout=self.attention_dropout,
+ dropout=self.dropout,
+ qkv_bias=self.qkv.bias,
+ proj_bias=self.proj.bias,
+ )
+
+
+class SwinTransformerBlock(nn.Module):
+ """
+ Swin Transformer Block.
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (List[int]): Window size.
+ shift_size (List[int]): Shift size for shifted window attention.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
+ dropout (float): Dropout rate. Default: 0.0.
+ attention_dropout (float): Attention dropout rate. Default: 0.0.
+ stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttention
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ window_size: List[int],
+ shift_size: List[int],
+ mlp_ratio: float = 4.0,
+ dropout: float = 0.0,
+ attention_dropout: float = 0.0,
+ stochastic_depth_prob: float = 0.0,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention,
+ ):
+ super().__init__()
+ _log_api_usage_once(self)
+
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_layer(
+ dim,
+ window_size,
+ shift_size,
+ num_heads,
+ attention_dropout=attention_dropout,
+ dropout=dropout,
+ )
+ self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
+ self.norm2 = norm_layer(dim)
+ self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)
+
+ for m in self.mlp.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.normal_(m.bias, std=1e-6)
+
+ def forward(self, x: Tensor):
+ x = x + self.stochastic_depth(self.attn(self.norm1(x)))
+ x = x + self.stochastic_depth(self.mlp(self.norm2(x)))
+ return x
+
+
+class SwinTransformer(nn.Module):
+ """
+ Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using
+ Shifted Windows" `_ paper.
+ Args:
+ patch_size (List[int]): Patch size.
+ embed_dim (int): Patch embedding dimension.
+ depths (List(int)): Depth of each Swin Transformer layer.
+ num_heads (List(int)): Number of attention heads in different layers.
+ window_size (List[int]): Window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
+ dropout (float): Dropout rate. Default: 0.0.
+ attention_dropout (float): Attention dropout rate. Default: 0.0.
+ stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0.
+ num_classes (int): Number of classes for classification head. Default: 1000.
+ block (nn.Module, optional): SwinTransformer Block. Default: None.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None.
+ """
+
+ def __init__(
+ self,
+ patch_size: List[int],
+ embed_dim: int,
+ depths: List[int],
+ num_heads: List[int],
+ window_size: List[int],
+ mlp_ratio: float = 4.0,
+ dropout: float = 0.0,
+ attention_dropout: float = 0.0,
+ stochastic_depth_prob: float = 0.0,
+ num_classes: int = 1000,
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
+ block: Optional[Callable[..., nn.Module]] = None,
+ ):
+ super().__init__()
+ _log_api_usage_once(self)
+ self.num_classes = num_classes
+
+ if block is None:
+ block = SwinTransformerBlock
+
+ if norm_layer is None:
+ norm_layer = partial(nn.LayerNorm, eps=1e-5)
+
+ layers: List[nn.Module] = []
+ # split image into non-overlapping patches
+ # layers.append(
+ # nn.Sequential(
+ # nn.Conv2d(
+ # 3, embed_dim, kernel_size=(patch_size[0], patch_size[1]), stride=(patch_size[0], patch_size[1])
+ # ),
+ # Permute([0, 2, 3, 1]),
+ # norm_layer(embed_dim),
+ # )
+ # )
+ self.first_coonv = nn.Sequential(
+ nn.Conv2d(
+ 3, embed_dim, kernel_size=(patch_size[0], patch_size[1]), stride=(patch_size[0], patch_size[1])
+ ),
+ Permute([0, 2, 3, 1]),
+ norm_layer(embed_dim),
+ )
+
+ total_stage_blocks = sum(depths)
+ stage_block_id = 0
+ # build SwinTransformer blocks
+ for i_stage in range(len(depths)):
+ stage: List[nn.Module] = []
+ dim = embed_dim * 2 ** i_stage
+ for i_layer in range(depths[i_stage]):
+ # adjust stochastic depth probability based on the depth of the stage block
+ sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1)
+ stage.append(
+ block(
+ dim,
+ num_heads[i_stage],
+ window_size=window_size,
+ shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size],
+ mlp_ratio=mlp_ratio,
+ dropout=dropout,
+ attention_dropout=attention_dropout,
+ stochastic_depth_prob=sd_prob,
+ norm_layer=norm_layer,
+ )
+ )
+ stage_block_id += 1
+ layers.append(nn.Sequential(*stage))
+ # add patch merging layer
+ if i_stage < (len(depths) - 1):
+ layers.append(PatchMerging(dim, norm_layer))
+ # self.features = nn.Sequential(*layers)
+ self.features = nn.ModuleList(layers)
+
+ num_features = embed_dim * 2 ** (len(depths) - 1)
+ self.norm = norm_layer(num_features)
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
+ self.head = nn.Linear(num_features, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ def forward(self, x):
+ feats = []
+ x = self.first_coonv(x)
+ for i, layer in enumerate(self.features):
+ x = layer(x)
+ if i in [0, 2, 4, 6]:
+ feats.append(x.permute(0, 3, 1, 2).contiguous())
+ # x = self.features(x)
+ # x = self.norm(x)
+ # x = x.permute(0, 3, 1, 2)
+ # x = self.avgpool(x)
+ # x = torch.flatten(x, 1)
+ # x = self.head(x)
+ return feats
+
+
+def _swin_transformer(
+ patch_size: List[int],
+ embed_dim: int,
+ depths: List[int],
+ num_heads: List[int],
+ window_size: List[int],
+ stochastic_depth_prob: float,
+ weights: Optional[WeightsEnum],
+ progress: bool,
+ **kwargs: Any,
+) -> SwinTransformer:
+ if weights is not None:
+ _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+
+ model = SwinTransformer(
+ patch_size=patch_size,
+ embed_dim=embed_dim,
+ depths=depths,
+ num_heads=num_heads,
+ window_size=window_size,
+ stochastic_depth_prob=stochastic_depth_prob,
+ **kwargs,
+ )
+
+ if weights is not None:
+ ckpt1 = weights.get_state_dict(progress=progress)
+ ckpt2 = model.state_dict()
+ kl1 = list(ckpt1.keys())
+ for i, k in enumerate(list(ckpt2.keys())):
+ ckpt2[k] = ckpt1[kl1[i]]
+ msg = model.load_state_dict(ckpt2, strict=False)
+ print(f'Load swin_transformer: {msg}')
+
+ return model
+
+
+_COMMON_META = {
+ "categories": _IMAGENET_CATEGORIES,
+}
+
+
+class Swin_T_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/swin_t-704ceda3.pth",
+ transforms=partial(
+ ImageClassification, crop_size=224, resize_size=232, interpolation=InterpolationMode.BICUBIC
+ ),
+ meta={
+ **_COMMON_META,
+ "num_params": 28288354,
+ "min_size": (224, 224),
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 81.474,
+ "acc@5": 95.776,
+ }
+ },
+ "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
+ },
+ )
+ DEFAULT = IMAGENET1K_V1
+
+
+class Swin_S_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/swin_s-5e29d889.pth",
+ transforms=partial(
+ ImageClassification, crop_size=224, resize_size=246, interpolation=InterpolationMode.BICUBIC
+ ),
+ meta={
+ **_COMMON_META,
+ "num_params": 49606258,
+ "min_size": (224, 224),
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 83.196,
+ "acc@5": 96.360,
+ }
+ },
+ "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
+ },
+ )
+ DEFAULT = IMAGENET1K_V1
+
+
+class Swin_B_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/swin_b-68c6b09e.pth",
+ transforms=partial(
+ ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC
+ ),
+ meta={
+ **_COMMON_META,
+ "num_params": 87768224,
+ "min_size": (224, 224),
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 83.582,
+ "acc@5": 96.640,
+ }
+ },
+ "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
+ },
+ )
+ DEFAULT = IMAGENET1K_V1
+
+
+def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
+ """
+ Constructs a swin_tiny architecture from
+ `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows `_.
+
+ Args:
+ weights (:class:`~torchvision.models.Swin_T_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.Swin_T_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+
+ .. autoclass:: torchvision.models.Swin_T_Weights
+ :members:
+ """
+ weights = Swin_T_Weights.verify(weights)
+
+ return _swin_transformer(
+ patch_size=[4, 4],
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=[7, 7],
+ stochastic_depth_prob=0.2,
+ weights=weights,
+ progress=progress,
+ **kwargs,
+ )
+
+
+def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
+ """
+ Constructs a swin_small architecture from
+ `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows `_.
+
+ Args:
+ weights (:class:`~torchvision.models.Swin_S_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.Swin_S_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+
+ .. autoclass:: torchvision.models.Swin_S_Weights
+ :members:
+ """
+ weights = Swin_S_Weights.verify(weights)
+
+ return _swin_transformer(
+ patch_size=[4, 4],
+ embed_dim=96,
+ depths=[2, 2, 18, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=[7, 7],
+ stochastic_depth_prob=0.3,
+ weights=weights,
+ progress=progress,
+ **kwargs,
+ )
+
+
+from torchvision.models._utils import handle_legacy_interface
+@handle_legacy_interface(weights=("pretrained", Swin_B_Weights.IMAGENET1K_V1))
+def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
+ """
+ Constructs a swin_base architecture from
+ `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows `_.
+
+ Args:
+ weights (:class:`~torchvision.models.Swin_B_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.Swin_B_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+
+ .. autoclass:: torchvision.models.Swin_B_Weights
+ :members:
+ """
+ weights = Swin_B_Weights.verify(weights)
+
+ return _swin_transformer(
+ patch_size=[4, 4],
+ embed_dim=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=[7, 7],
+ stochastic_depth_prob=0.5,
+ weights=weights,
+ progress=progress,
+ **kwargs,
+ )
+
+if __name__ == '__main__':
+ model = swin_b(weights=Swin_B_Weights)
+ x = torch.rand(1, 3, 320, 320)
+ y = model(x)
+ pause = 0
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/uncond_unet.py b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/uncond_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6e6dd3f7e1e8a69ba9e36009b5a1088cd82726f
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/uncond_unet.py
@@ -0,0 +1,376 @@
+import math
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+from einops import rearrange, reduce
+from functools import partial
+
+
+def exists(x):
+ return x is not None
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if callable(d) else d
+
+def identity(t, *args, **kwargs):
+ return t
+
+def cycle(dl):
+ while True:
+ for data in dl:
+ yield data
+
+def has_int_squareroot(num):
+ return (math.sqrt(num) ** 2) == num
+
+def num_to_groups(num, divisor):
+ groups = num // divisor
+ remainder = num % divisor
+ arr = [divisor] * groups
+ if remainder > 0:
+ arr.append(remainder)
+ return arr
+
+def convert_image_to_fn(img_type, image):
+ if image.mode != img_type:
+ return image.convert(img_type)
+ return image
+
+# normalization functions
+
+def normalize_to_neg_one_to_one(img):
+ return img * 2 - 1
+
+def unnormalize_to_zero_to_one(t):
+ return (t + 1) * 0.5
+
+# small helper modules
+
+class Residual(nn.Module):
+ def __init__(self, fn):
+ super().__init__()
+ self.fn = fn
+
+ def forward(self, x, *args, **kwargs):
+ return self.fn(x, *args, **kwargs) + x
+
+def Upsample(dim, dim_out = None):
+ return nn.Sequential(
+ nn.Upsample(scale_factor = 2, mode = 'nearest'),
+ nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1)
+ )
+
+def Downsample(dim, dim_out = None):
+ return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1)
+
+class WeightStandardizedConv2d(nn.Conv2d):
+ """
+ https://arxiv.org/abs/1903.10520
+ weight standardization purportedly works synergistically with group normalization
+ """
+ def forward(self, x):
+ eps = 1e-5 if x.dtype == torch.float32 else 1e-3
+
+ weight = self.weight
+ mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
+ var = reduce(weight, 'o ... -> o 1 1 1', partial(torch.var, unbiased = False))
+ normalized_weight = (weight - mean) * (var + eps).rsqrt()
+
+ return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+class LayerNorm(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
+
+ def forward(self, x):
+ eps = 1e-5 if x.dtype == torch.float32 else 1e-3
+ var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
+ mean = torch.mean(x, dim = 1, keepdim = True)
+ return (x - mean) * (var + eps).rsqrt() * self.g
+
+class PreNorm(nn.Module):
+ def __init__(self, dim, fn):
+ super().__init__()
+ self.fn = fn
+ self.norm = LayerNorm(dim)
+
+ def forward(self, x):
+ x = self.norm(x)
+ return self.fn(x)
+
+# sinusoidal positional embeds
+
+class SinusoidalPosEmb(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, x):
+ device = x.device
+ half_dim = self.dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+ emb = x[:, None] * emb[None, :]
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+ return emb
+
+class RandomOrLearnedSinusoidalPosEmb(nn.Module):
+ """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """
+ """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
+
+ def __init__(self, dim, is_random = False):
+ super().__init__()
+ assert (dim % 2) == 0
+ half_dim = dim // 2
+ self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random)
+
+ def forward(self, x):
+ x = rearrange(x, 'b -> b 1')
+ freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
+ fouriered = torch.cat((x, fouriered), dim = -1)
+ return fouriered
+
+# building block modules
+
+class Block(nn.Module):
+ def __init__(self, dim, dim_out, groups = 8):
+ super().__init__()
+ self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1)
+ self.norm = nn.GroupNorm(groups, dim_out)
+ self.act = nn.SiLU()
+
+ def forward(self, x, scale_shift = None):
+ x = self.proj(x)
+ x = self.norm(x)
+
+ if exists(scale_shift):
+ scale, shift = scale_shift
+ x = x * (scale + 1) + shift
+
+ x = self.act(x)
+ return x
+
+class ResnetBlock(nn.Module):
+ def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(time_emb_dim, dim_out * 2)
+ ) if exists(time_emb_dim) else None
+
+ self.block1 = Block(dim, dim_out, groups = groups)
+ self.block2 = Block(dim_out, dim_out, groups = groups)
+ self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
+
+ def forward(self, x, time_emb = None):
+
+ scale_shift = None
+ if exists(self.mlp) and exists(time_emb):
+ time_emb = self.mlp(time_emb)
+ time_emb = rearrange(time_emb, 'b c -> b c 1 1')
+ scale_shift = time_emb.chunk(2, dim = 1)
+
+ h = self.block1(x, scale_shift = scale_shift)
+
+ h = self.block2(h)
+
+ return h + self.res_conv(x)
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads = 4, dim_head = 32):
+ super().__init__()
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
+
+ self.to_out = nn.Sequential(
+ nn.Conv2d(hidden_dim, dim, 1),
+ LayerNorm(dim)
+ )
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x).chunk(3, dim = 1)
+ q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
+
+ q = q.softmax(dim = -2)
+ k = k.softmax(dim = -1)
+
+ q = q * self.scale
+ v = v / (h * w)
+
+ context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
+
+ out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
+ out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
+ return self.to_out(out)
+
+class Attention(nn.Module):
+ def __init__(self, dim, heads = 4, dim_head = 32):
+ super().__init__()
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+ hidden_dim = dim_head * heads
+
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x).chunk(3, dim = 1)
+ q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
+
+ q = q * self.scale
+
+ sim = einsum('b h d i, b h d j -> b h i j', q, k)
+ attn = sim.softmax(dim = -1)
+ out = einsum('b h i j, b h d j -> b h i d', attn, v)
+
+ out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
+ return self.to_out(out)
+
+# model
+
+class Unet(nn.Module):
+ def __init__(
+ self,
+ dim,
+ init_dim = None,
+ out_dim = None,
+ dim_mults=(1, 2, 4, 8),
+ channels = 3,
+ self_condition = False,
+ resnet_block_groups = 8,
+ heads=8,
+ learned_variance = False,
+ learned_sinusoidal_cond = False,
+ random_fourier_features = False,
+ learned_sinusoidal_dim = 16,
+ out_mul=1,
+ ):
+ super().__init__()
+
+ # determine dimensions
+
+ self.channels = channels
+ self.self_condition = self_condition
+ input_channels = channels * (2 if self_condition else 1)
+
+ init_dim = default(init_dim, dim)
+ self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)
+
+ dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
+ in_out = list(zip(dims[:-1], dims[1:]))
+
+ block_klass = partial(ResnetBlock, groups = resnet_block_groups)
+
+ # time embeddings
+
+ time_dim = dim * 4
+
+ self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features
+
+ if self.random_or_learned_sinusoidal_cond:
+ sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
+ fourier_dim = learned_sinusoidal_dim + 1
+ else:
+ sinu_pos_emb = SinusoidalPosEmb(dim)
+ fourier_dim = dim
+
+ self.time_mlp = nn.Sequential(
+ sinu_pos_emb,
+ nn.Linear(fourier_dim, time_dim),
+ nn.GELU(),
+ nn.Linear(time_dim, time_dim)
+ )
+
+ # layers
+
+ self.downs = nn.ModuleList([])
+ self.ups = nn.ModuleList([])
+ num_resolutions = len(in_out)
+
+ for ind, (dim_in, dim_out) in enumerate(in_out):
+ is_last = ind >= (num_resolutions - 1)
+
+ self.downs.append(nn.ModuleList([
+ block_klass(dim_in, dim_in, time_emb_dim = time_dim),
+ block_klass(dim_in, dim_in, time_emb_dim = time_dim),
+ Residual(PreNorm(dim_in, LinearAttention(dim_in, heads=heads))),
+ Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
+ ]))
+
+ mid_dim = dims[-1]
+ self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
+ self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim, heads=heads)))
+ self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
+
+ for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
+ is_last = ind == (len(in_out) - 1)
+
+ self.ups.append(nn.ModuleList([
+ block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
+ block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
+ Residual(PreNorm(dim_out, LinearAttention(dim_out))),
+ Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
+ ]))
+
+ default_out_dim = channels * out_mul
+ self.out_dim = default(out_dim, default_out_dim)
+
+ self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)
+ self.final_conv = nn.Conv2d(dim, self.out_dim, 1)
+
+ def forward(self, x, time, cond=None, x_self_cond=None): ## cond is always None for unconditional model
+ if self.self_condition:
+ x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
+ x = torch.cat((x_self_cond, x), dim = 1)
+
+ x = self.init_conv(x)
+ r = x.clone()
+
+ t = self.time_mlp(time)
+
+ h = []
+
+ for block1, block2, attn, downsample in self.downs:
+ x = block1(x, t)
+ h.append(x)
+
+ x = block2(x, t)
+ x = attn(x)
+ h.append(x)
+
+ x = downsample(x)
+
+ x = self.mid_block1(x, t)
+ x = self.mid_attn(x)
+ x = self.mid_block2(x, t)
+
+ for block1, block2, attn, upsample in self.ups:
+ x = torch.cat((x, h.pop()), dim = 1)
+ x = block1(x, t)
+
+ x = torch.cat((x, h.pop()), dim = 1)
+ x = block2(x, t)
+ x = attn(x)
+
+ x = upsample(x)
+
+ x = torch.cat((x, r), dim = 1)
+
+ x = self.final_res_block(x, t)
+ return self.final_conv(x)
+
+if __name__ == '__main__':
+ model = Unet(96, out_mul=2, dim_mults=[1,2,4,8], heads=8)
+ x = torch.rand(2, 3, 8, 8)
+ time = torch.tensor([2, 5])
+ with torch.no_grad():
+ y = model(x, time)
+ pass
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/utils.py b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4975a5de594ac8fed9370d95fe978de0b36fc230
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/utils.py
@@ -0,0 +1,68 @@
+import os
+from pathlib import Path
+import time
+import logging
+import math
+
+def create_logger(root_dir, des=''):
+ root_output_dir = Path(root_dir)
+ # set up logger
+ if not root_output_dir.exists():
+ print('=> creating {}'.format(root_output_dir))
+ root_output_dir.mkdir(exist_ok=True, parents=True)
+ time_str = time.strftime('%Y-%m-%d-%H-%M')
+ log_file = '{}_{}.log'.format(time_str, des)
+ final_log_file = root_output_dir / log_file
+ head = '%(asctime)-15s %(message)s'
+ logging.basicConfig(filename=str(final_log_file), format=head)
+ logger = logging.getLogger()
+ logger.setLevel(logging.INFO)
+ console = logging.StreamHandler()
+ logging.getLogger('').addHandler(console)
+ return logger
+
+def exists(x):
+ return x is not None
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if callable(d) else d
+
+def identity(t, *args, **kwargs):
+ return t
+
+def cycle(dl):
+ while True:
+ for data in dl:
+ yield data
+
+def has_int_squareroot(num):
+ return (math.sqrt(num) ** 2) == num
+
+def num_to_groups(num, divisor):
+ groups = num // divisor
+ remainder = num % divisor
+ arr = [divisor] * groups
+ if remainder > 0:
+ arr.append(remainder)
+ return arr
+
+def convert_image_to_fn(img_type, image):
+ if image.mode != img_type:
+ return image.convert(img_type)
+ return image
+
+# normalization functions
+
+def normalize_to_neg_one_to_one(img):
+ return img * 2 - 1
+
+def unnormalize_to_zero_to_one(t):
+ return (t + 1) * 0.5
+
+def dict2str(dict):
+ s = ''
+ for k, v in dict.items():
+ s += "{}: {:.5f}, ".format(k, v)
+ return s
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/vgg.py b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..e33c626885864748ea395b0db73aa500edde05d7
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/vgg.py
@@ -0,0 +1,517 @@
+from functools import partial
+from typing import Union, List, Dict, Any, Optional, cast
+
+import torch
+import torch.nn as nn
+
+from torchvision.transforms._presets import ImageClassification
+from torchvision.utils import _log_api_usage_once
+from torchvision.models._api import WeightsEnum, Weights
+from torchvision.models._meta import _IMAGENET_CATEGORIES
+from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param
+
+
+__all__ = [
+ "VGG",
+ "VGG11_Weights",
+ "VGG11_BN_Weights",
+ "VGG13_Weights",
+ "VGG13_BN_Weights",
+ "VGG16_Weights",
+ "VGG16_BN_Weights",
+ "VGG19_Weights",
+ "VGG19_BN_Weights",
+ "vgg11",
+ "vgg11_bn",
+ "vgg13",
+ "vgg13_bn",
+ "vgg16",
+ "vgg16_bn",
+ "vgg19",
+ "vgg19_bn",
+]
+
+
+class VGG(nn.Module):
+ def __init__(
+ self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5
+ ) -> None:
+ super().__init__()
+ _log_api_usage_once(self)
+ self.features = features
+ self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
+ self.classifier = nn.Sequential(
+ nn.Linear(512 * 7 * 7, 4096),
+ nn.ReLU(True),
+ nn.Dropout(p=dropout),
+ nn.Linear(4096, 4096),
+ nn.ReLU(True),
+ nn.Dropout(p=dropout),
+ nn.Linear(4096, num_classes),
+ )
+ if init_weights:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, 0, 0.01)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ feats = []
+ # x = self.features(x)
+ # x = self.avgpool(x)
+ # x = torch.flatten(x, 1)
+ # x = self.classifier(x)
+ for i, layer in enumerate(self.features):
+ x = layer(x)
+ if i in [9, 16, 23, 30]:
+ feats.append(x)
+ return feats
+
+
+def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
+ layers: List[nn.Module] = []
+ in_channels = 3
+ for v in cfg:
+ if v == "M":
+ layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
+ else:
+ v = cast(int, v)
+ conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
+ if batch_norm:
+ layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
+ else:
+ layers += [conv2d, nn.ReLU(inplace=True)]
+ in_channels = v
+ return nn.ModuleList(layers)
+
+
+cfgs: Dict[str, List[Union[str, int]]] = {
+ "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
+ "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
+ "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
+ "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
+}
+
+
+def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG:
+ if weights is not None:
+ kwargs["init_weights"] = False
+ if weights.meta["categories"] is not None:
+ _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+ model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
+
+ if weights is not None:
+ ckpt1 = weights.get_state_dict(progress=progress)
+ ckpt2 = model.state_dict()
+ kl1 = list(ckpt1.keys())
+ for i, k in enumerate(list(ckpt2.keys())):
+ ckpt2[k] = ckpt1[kl1[i]]
+ msg = model.load_state_dict(ckpt2, strict=False)
+ print(f'Load VGG: {msg}')
+ else:
+ print('No pretrained weight loaded!')
+ return model
+
+
+_COMMON_META = {
+ "min_size": (32, 32),
+ "categories": _IMAGENET_CATEGORIES,
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
+ "_docs": """These weights were trained from scratch by using a simplified training recipe.""",
+}
+
+
+class VGG11_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/vgg11-8a719046.pth",
+ transforms=partial(ImageClassification, crop_size=224),
+ meta={
+ **_COMMON_META,
+ "num_params": 132863336,
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 69.020,
+ "acc@5": 88.628,
+ }
+ },
+ },
+ )
+ DEFAULT = IMAGENET1K_V1
+
+
+class VGG11_BN_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
+ transforms=partial(ImageClassification, crop_size=224),
+ meta={
+ **_COMMON_META,
+ "num_params": 132868840,
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 70.370,
+ "acc@5": 89.810,
+ }
+ },
+ },
+ )
+ DEFAULT = IMAGENET1K_V1
+
+
+class VGG13_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/vgg13-19584684.pth",
+ transforms=partial(ImageClassification, crop_size=224),
+ meta={
+ **_COMMON_META,
+ "num_params": 133047848,
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 69.928,
+ "acc@5": 89.246,
+ }
+ },
+ },
+ )
+ DEFAULT = IMAGENET1K_V1
+
+
+class VGG13_BN_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
+ transforms=partial(ImageClassification, crop_size=224),
+ meta={
+ **_COMMON_META,
+ "num_params": 133053736,
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 71.586,
+ "acc@5": 90.374,
+ }
+ },
+ },
+ )
+ DEFAULT = IMAGENET1K_V1
+
+
+class VGG16_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/vgg16-397923af.pth",
+ transforms=partial(ImageClassification, crop_size=224),
+ meta={
+ **_COMMON_META,
+ "num_params": 138357544,
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 71.592,
+ "acc@5": 90.382,
+ }
+ },
+ },
+ )
+ IMAGENET1K_FEATURES = Weights(
+ # Weights ported from https://github.com/amdegroot/ssd.pytorch/
+ url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth",
+ transforms=partial(
+ ImageClassification,
+ crop_size=224,
+ mean=(0.48235, 0.45882, 0.40784),
+ std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0),
+ ),
+ meta={
+ **_COMMON_META,
+ "num_params": 138357544,
+ "categories": None,
+ "recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd",
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": float("nan"),
+ "acc@5": float("nan"),
+ }
+ },
+ "_docs": """
+ These weights can't be used for classification because they are missing values in the `classifier`
+ module. Only the `features` module has valid values and can be used for feature extraction. The weights
+ were trained using the original input standardization method as described in the paper.
+ """,
+ },
+ )
+ DEFAULT = IMAGENET1K_V1
+
+
+class VGG16_BN_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
+ transforms=partial(ImageClassification, crop_size=224),
+ meta={
+ **_COMMON_META,
+ "num_params": 138365992,
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 73.360,
+ "acc@5": 91.516,
+ }
+ },
+ },
+ )
+ DEFAULT = IMAGENET1K_V1
+
+
+class VGG19_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
+ transforms=partial(ImageClassification, crop_size=224),
+ meta={
+ **_COMMON_META,
+ "num_params": 143667240,
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 72.376,
+ "acc@5": 90.876,
+ }
+ },
+ },
+ )
+ DEFAULT = IMAGENET1K_V1
+
+
+class VGG19_BN_Weights(WeightsEnum):
+ IMAGENET1K_V1 = Weights(
+ url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
+ transforms=partial(ImageClassification, crop_size=224),
+ meta={
+ **_COMMON_META,
+ "num_params": 143678248,
+ "_metrics": {
+ "ImageNet-1K": {
+ "acc@1": 74.218,
+ "acc@5": 91.842,
+ }
+ },
+ },
+ )
+ DEFAULT = IMAGENET1K_V1
+
+
+@handle_legacy_interface(weights=("pretrained", VGG11_Weights.IMAGENET1K_V1))
+def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
+ """VGG-11 from `Very Deep Convolutional Networks for Large-Scale Image Recognition `__.
+
+ Args:
+ weights (:class:`~torchvision.models.VGG11_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.VGG11_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+
+ .. autoclass:: torchvision.models.VGG11_Weights
+ :members:
+ """
+ weights = VGG11_Weights.verify(weights)
+
+ return _vgg("A", False, weights, progress, **kwargs)
+
+
+@handle_legacy_interface(weights=("pretrained", VGG11_BN_Weights.IMAGENET1K_V1))
+def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
+ """VGG-11-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition `__.
+
+ Args:
+ weights (:class:`~torchvision.models.VGG11_BN_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.VGG11_BN_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+
+ .. autoclass:: torchvision.models.VGG11_BN_Weights
+ :members:
+ """
+ weights = VGG11_BN_Weights.verify(weights)
+
+ return _vgg("A", True, weights, progress, **kwargs)
+
+
+@handle_legacy_interface(weights=("pretrained", VGG13_Weights.IMAGENET1K_V1))
+def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
+ """VGG-13 from `Very Deep Convolutional Networks for Large-Scale Image Recognition `__.
+
+ Args:
+ weights (:class:`~torchvision.models.VGG13_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.VGG13_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+
+ .. autoclass:: torchvision.models.VGG13_Weights
+ :members:
+ """
+ weights = VGG13_Weights.verify(weights)
+
+ return _vgg("B", False, weights, progress, **kwargs)
+
+
+@handle_legacy_interface(weights=("pretrained", VGG13_BN_Weights.IMAGENET1K_V1))
+def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
+ """VGG-13-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition `__.
+
+ Args:
+ weights (:class:`~torchvision.models.VGG13_BN_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.VGG13_BN_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+
+ .. autoclass:: torchvision.models.VGG13_BN_Weights
+ :members:
+ """
+ weights = VGG13_BN_Weights.verify(weights)
+
+ return _vgg("B", True, weights, progress, **kwargs)
+
+
+@handle_legacy_interface(weights=("pretrained", VGG16_Weights.IMAGENET1K_V1))
+def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
+ """VGG-16 from `Very Deep Convolutional Networks for Large-Scale Image Recognition `__.
+
+ Args:
+ weights (:class:`~torchvision.models.VGG16_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.VGG16_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+
+ .. autoclass:: torchvision.models.VGG16_Weights
+ :members:
+ """
+ weights = VGG16_Weights.verify(weights)
+
+ return _vgg("D", False, weights, progress, **kwargs)
+
+
+@handle_legacy_interface(weights=("pretrained", VGG16_BN_Weights.IMAGENET1K_V1))
+def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
+ """VGG-16-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition `__.
+
+ Args:
+ weights (:class:`~torchvision.models.VGG16_BN_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.VGG16_BN_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+
+ .. autoclass:: torchvision.models.VGG16_BN_Weights
+ :members:
+ """
+ weights = VGG16_BN_Weights.verify(weights)
+
+ return _vgg("D", True, weights, progress, **kwargs)
+
+
+@handle_legacy_interface(weights=("pretrained", VGG19_Weights.IMAGENET1K_V1))
+def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
+ """VGG-19 from `Very Deep Convolutional Networks for Large-Scale Image Recognition `__.
+
+ Args:
+ weights (:class:`~torchvision.models.VGG19_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.VGG19_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+
+ .. autoclass:: torchvision.models.VGG19_Weights
+ :members:
+ """
+ weights = VGG19_Weights.verify(weights)
+
+ return _vgg("E", False, weights, progress, **kwargs)
+
+
+@handle_legacy_interface(weights=("pretrained", VGG19_BN_Weights.IMAGENET1K_V1))
+def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
+ """VGG-19_BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition `__.
+
+ Args:
+ weights (:class:`~torchvision.models.VGG19_BN_Weights`, optional): The
+ pretrained weights to use. See
+ :class:`~torchvision.models.VGG19_BN_Weights` below for
+ more details, and possible values. By default, no pre-trained
+ weights are used.
+ progress (bool, optional): If True, displays a progress bar of the
+ download to stderr. Default is True.
+ **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
+ base class. Please refer to the `source code
+ `_
+ for more details about this class.
+
+ .. autoclass:: torchvision.models.VGG19_BN_Weights
+ :members:
+ """
+ weights = VGG19_BN_Weights.verify(weights)
+
+ return _vgg("E", True, weights, progress, **kwargs)
+
+
+# The dictionary below is internal implementation detail and will be removed in v0.15
+from torchvision.models._utils import _ModelURLs
+
+
+model_urls = _ModelURLs(
+ {
+ "vgg11": VGG11_Weights.IMAGENET1K_V1.url,
+ "vgg13": VGG13_Weights.IMAGENET1K_V1.url,
+ "vgg16": VGG16_Weights.IMAGENET1K_V1.url,
+ "vgg19": VGG19_Weights.IMAGENET1K_V1.url,
+ "vgg11_bn": VGG11_BN_Weights.IMAGENET1K_V1.url,
+ "vgg13_bn": VGG13_BN_Weights.IMAGENET1K_V1.url,
+ "vgg16_bn": VGG16_BN_Weights.IMAGENET1K_V1.url,
+ "vgg19_bn": VGG19_BN_Weights.IMAGENET1K_V1.url,
+ }
+)
diff --git a/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/wavelet.py b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/wavelet.py
new file mode 100644
index 0000000000000000000000000000000000000000..b051a3e10aaa33bdb8852a842229ffaaba68df0a
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/wavelet.py
@@ -0,0 +1,83 @@
+import pywt
+import pywt.data
+import torch
+from torch import nn
+from torch.autograd import Function
+import torch.nn.functional as F
+
+
+def create_wavelet_filter(wave, in_size, out_size, type=torch.float):
+ w = pywt.Wavelet(wave)
+ dec_hi = torch.tensor(w.dec_hi[::-1], dtype=type)
+ dec_lo = torch.tensor(w.dec_lo[::-1], dtype=type)
+ dec_filters = torch.stack([dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1),
+ dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1),
+ dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1),
+ dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)], dim=0)
+
+ dec_filters = dec_filters[:, None].repeat(in_size, 1, 1, 1)
+
+ rec_hi = torch.tensor(w.rec_hi[::-1], dtype=type).flip(dims=[0])
+ rec_lo = torch.tensor(w.rec_lo[::-1], dtype=type).flip(dims=[0])
+ rec_filters = torch.stack([rec_lo.unsqueeze(0) * rec_lo.unsqueeze(1),
+ rec_lo.unsqueeze(0) * rec_hi.unsqueeze(1),
+ rec_hi.unsqueeze(0) * rec_lo.unsqueeze(1),
+ rec_hi.unsqueeze(0) * rec_hi.unsqueeze(1)], dim=0)
+
+ rec_filters = rec_filters[:, None].repeat(out_size, 1, 1, 1)
+
+ return dec_filters, rec_filters
+
+
+def wt(x, filters, in_size, level):
+ _, _, h, w = x.shape
+ pad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)
+ res = F.conv2d(x, filters, stride=2, groups=in_size, padding=pad)
+ if level > 1:
+ res[:, ::4] = wt(res[:, ::4], filters, in_size, level - 1)
+ res = res.reshape(-1, 2, h // 2, w // 2).transpose(1, 2).reshape(-1, in_size, h, w)
+ return res
+
+
+def iwt(x, inv_filters, in_size, level):
+ _, _, h, w = x.shape
+ pad = (inv_filters.shape[2] // 2 - 1, inv_filters.shape[3] // 2 - 1)
+ res = x.reshape(-1, h // 2, 2, w // 2).transpose(1, 2).reshape(-1, 4 * in_size, h // 2, w // 2)
+ if level > 1:
+ res[:, ::4] = iwt(res[:, ::4], inv_filters, in_size, level - 1)
+ res = F.conv_transpose2d(res, inv_filters, stride=2, groups=in_size, padding=pad)
+ return res
+
+
+def get_inverse_transform(weights, in_size, level):
+ class InverseWaveletTransform(Function):
+
+ @staticmethod
+ def forward(ctx, input):
+ with torch.no_grad():
+ x = iwt(input, weights, in_size, level)
+ return x
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad = wt(grad_output, weights, in_size, level)
+ return grad, None
+
+ return InverseWaveletTransform().apply
+
+
+def get_transform(weights, in_size, level):
+ class WaveletTransform(Function):
+
+ @staticmethod
+ def forward(ctx, input):
+ with torch.no_grad():
+ x = wt(input, weights, in_size, level)
+ return x
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad = iwt(grad_output, weights, in_size, level)
+ return grad, None
+
+ return WaveletTransform().apply
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/wcc.py b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/wcc.py
new file mode 100644
index 0000000000000000000000000000000000000000..5faeb37b717f47a9f526156ef77ac2711a2eb738
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/wcc.py
@@ -0,0 +1,101 @@
+from typing import Union, Tuple
+
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.quantization import weight_quantize_fn, act_quantize_fn
+from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch import wavelet
+
+
+class WCC(nn.Conv1d):
+ def __init__(self, in_channels: int,
+ out_channels: int,
+ stride: Union[int, Tuple] = 1,
+ padding: Union[int, Tuple] = 0,
+ dilation: Union[int, Tuple] = 1,
+ groups: int = 1,
+ bias: bool = False,
+ levels: int = 3,
+ compress_rate: float = 0.25,
+ bit_w: int = 8,
+ bit_a: int = 8,
+ wt_type: str = "db1"):
+ super(WCC, self).__init__(in_channels, out_channels, 1, stride, padding, dilation, groups, bias)
+ self.layer_type = 'WCC'
+ self.bit_w = bit_w
+ self.bit_a = bit_a
+
+ self.weight_quant = weight_quantize_fn(self.bit_w)
+ self.act_quant = act_quantize_fn(self.bit_a, signed=True)
+
+ self.levels = levels
+ self.wt_type = wt_type
+ self.compress_rate = compress_rate
+
+ dec_filters, rec_filters = wavelet.create_wavelet_filter(wave=self.wt_type,
+ in_size=in_channels,
+ out_size=out_channels)
+ self.wt_filters = nn.Parameter(dec_filters, requires_grad=False)
+ self.iwt_filters = nn.Parameter(rec_filters, requires_grad=False)
+ self.wt = wavelet.get_transform(self.wt_filters, in_channels, levels)
+ self.iwt = wavelet.get_inverse_transform(self.iwt_filters, out_channels, levels)
+
+ self.get_pad = lambda n: ((2 ** levels) - n) % (2 ** levels)
+
+ def forward(self, x):
+ in_shape = x.shape
+ pads = (0, self.get_pad(in_shape[2]), 0, self.get_pad(in_shape[3]))
+ x = F.pad(x, pads) # pad to match 2^(levels)
+
+ weight_q = self.weight_quant(self.weight) # quantize weights
+ x = self.wt(x) # H
+ topk, ids = self.compress(x) # T
+ topk_q = self.act_quant(topk) # quantize activations
+ topk_q = F.conv1d(topk_q, weight_q, self.bias, self.stride, self.padding, self.dilation, self.groups) # K_1x1
+ x = self.decompress(topk_q, ids, x.shape) # T^T
+ x = self.iwt(x) # H^T
+
+ x = x[:, :, :in_shape[2], :in_shape[3]] # remove pads
+ return x
+
+ def compress(self, x):
+ b, c, h, w = x.shape
+ acc = x.norm(dim=1).pow(2)
+ acc = acc.view(b, h * w)
+ k = int(h * w * self.compress_rate)
+ ids = acc.topk(k, dim=1, sorted=False)[1]
+ ids.unsqueeze_(dim=1)
+ topk = x.reshape((b, c, h * w)).gather(dim=2, index=ids.repeat(1, c, 1))
+ return topk, ids
+
+ def decompress(self, topk, ids, shape):
+ b, _, h, w = shape
+ ids = ids.repeat(1, self.out_channels, 1)
+ x = torch.zeros(size=(b, self.out_channels, h * w), requires_grad=True, device=topk.device)
+ x = x.scatter(dim=2, index=ids, src=topk)
+ x = x.reshape((b, self.out_channels, h, w))
+ return x
+
+ def change_wt_params(self, compress_rate, levels, wt_type="db1"):
+ self.compress_rate = compress_rate
+ self.levels = levels
+ dec_filters, rec_filters = wavelet.create_wavelet_filter(wave=self.wt_type,
+ in_size=self.in_channels,
+ out_size=self.out_channels)
+ self.wt_filters = nn.Parameter(dec_filters, requires_grad=False)
+ self.iwt_filters = nn.Parameter(rec_filters, requires_grad=False)
+ self.wt = wavelet.get_transform(self.wt_filters, self.in_channels, levels)
+ self.iwt = wavelet.get_inverse_transform(self.iwt_filters, self.out_channels, levels)
+
+ def change_bit(self, bit_w, bit_a):
+ self.bit_w = bit_w
+ self.bit_a = bit_a
+ self.weight_quant.change_bit(bit_w)
+ self.act_quant.change_bit(bit_a)
+
+if __name__ == '__main__':
+ wcc = WCC(80, 80)
+ x = torch.rand(1, 80, 80, 80)
+ y = wcc(x)
+ pause = 0
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/diffusion_edge/model.py b/src/custom_controlnet_aux/diffusion_edge/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3c86b512503ba681ed30cb894ec937ba188ee4e
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/model.py
@@ -0,0 +1,197 @@
+import numpy as np
+import yaml
+import argparse
+import math
+import torch
+from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.utils import *
+from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.encoder_decoder import AutoencoderKL
+# from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.transmodel import TransModel
+from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.uncond_unet import Unet
+from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.data import *
+from fvcore.common.config import CfgNode
+from pathlib import Path
+
+def load_conf(config_file, conf={}):
+ with open(config_file) as f:
+ exp_conf = yaml.load(f, Loader=yaml.FullLoader)
+ for k, v in exp_conf.items():
+ conf[k] = v
+ return conf
+
+def prepare_args(ckpt_path, sampling_timesteps=1):
+ return argparse.Namespace(
+ cfg=load_conf(Path(__file__).parent / "default.yaml"),
+ pre_weight=ckpt_path,
+ sampling_timesteps=sampling_timesteps
+ )
+
+class DiffusionEdge:
+ def __init__(self, args) -> None:
+ self.cfg = CfgNode(args.cfg)
+ torch.manual_seed(42)
+ np.random.seed(42)
+ model_cfg = self.cfg.model
+ first_stage_cfg = model_cfg.first_stage
+ first_stage_model = AutoencoderKL(
+ ddconfig=first_stage_cfg.ddconfig,
+ lossconfig=first_stage_cfg.lossconfig,
+ embed_dim=first_stage_cfg.embed_dim,
+ ckpt_path=first_stage_cfg.ckpt_path,
+ )
+ if model_cfg.model_name == 'cond_unet':
+ from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.mask_cond_unet import Unet
+ unet_cfg = model_cfg.unet
+ unet = Unet(dim=unet_cfg.dim,
+ channels=unet_cfg.channels,
+ dim_mults=unet_cfg.dim_mults,
+ learned_variance=unet_cfg.get('learned_variance', False),
+ out_mul=unet_cfg.out_mul,
+ cond_in_dim=unet_cfg.cond_in_dim,
+ cond_dim=unet_cfg.cond_dim,
+ cond_dim_mults=unet_cfg.cond_dim_mults,
+ window_sizes1=unet_cfg.window_sizes1,
+ window_sizes2=unet_cfg.window_sizes2,
+ fourier_scale=unet_cfg.fourier_scale,
+ cfg=unet_cfg,
+ )
+ else:
+ raise NotImplementedError
+ if model_cfg.model_type == 'const_sde':
+ from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.ddm_const_sde import LatentDiffusion
+ else:
+ raise NotImplementedError(f'{model_cfg.model_type} is not surportted !')
+
+ self.model = LatentDiffusion(
+ model=unet,
+ auto_encoder=first_stage_model,
+ train_sample=model_cfg.train_sample,
+ image_size=model_cfg.image_size,
+ timesteps=model_cfg.timesteps,
+ sampling_timesteps=args.sampling_timesteps,
+ loss_type=model_cfg.loss_type,
+ objective=model_cfg.objective,
+ scale_factor=model_cfg.scale_factor,
+ scale_by_std=model_cfg.scale_by_std,
+ scale_by_softsign=model_cfg.scale_by_softsign,
+ default_scale=model_cfg.get('default_scale', False),
+ input_keys=model_cfg.input_keys,
+ ckpt_path=model_cfg.ckpt_path,
+ ignore_keys=model_cfg.ignore_keys,
+ only_model=model_cfg.only_model,
+ start_dist=model_cfg.start_dist,
+ perceptual_weight=model_cfg.perceptual_weight,
+ use_l1=model_cfg.get('use_l1', True),
+ cfg=model_cfg,
+ )
+ self.cfg.sampler.ckpt_path = args.pre_weight
+
+ data = torch.load(self.cfg.sampler.ckpt_path, map_location="cpu")
+ if self.cfg.sampler.use_ema:
+ sd = data['ema']
+ new_sd = {}
+ for k in sd.keys():
+ if k.startswith("ema_model."):
+ new_k = k[10:] # remove ema_model.
+ new_sd[new_k] = sd[k]
+ sd = new_sd
+ self.model.load_state_dict(sd)
+ else:
+ self.model.load_state_dict(data['model'])
+ if 'scale_factor' in data['model']:
+ self.model.scale_factor = data['model']['scale_factor']
+
+ self.model.eval()
+ self.device = "cpu"
+
+ def to(self, device):
+ self.model.to(device)
+ self.device = device
+ return self
+
+ def __call__(self, image, batch_size=8):
+ image = normalize_to_neg_one_to_one(image).to(self.device)
+ mask = None
+ if self.cfg.sampler.sample_type == 'whole':
+ return self.whole_sample(image, raw_size=image.shape[2:], mask=mask)
+ elif self.cfg.sampler.sample_type == 'slide':
+ return self.slide_sample(image, crop_size=self.cfg.sampler.get('crop_size', [320, 320]),
+ stride=self.cfg.sampler.stride, mask=mask, bs=batch_size)
+
+ def whole_sample(self, inputs, raw_size, mask=None):
+ inputs = F.interpolate(inputs, size=(416, 416), mode='bilinear', align_corners=True)
+ seg_logits = self.model.sample(batch_size=inputs.shape[0], cond=inputs, mask=mask)
+ seg_logits = F.interpolate(seg_logits, size=raw_size, mode='bilinear', align_corners=True)
+ return seg_logits
+
+ def slide_sample(self, inputs, crop_size, stride, mask=None, bs=8):
+ """Inference by sliding-window with overlap.
+
+ If h_crop > h_img or w_crop > w_img, the small patch will be used to
+ decode without padding.
+
+ Args:
+ inputs (tensor): the tensor should have a shape NxCxHxW,
+ which contains all images in the batch.
+ batch_img_metas (List[dict]): List of image metainfo where each may
+ also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
+ 'ori_shape', and 'pad_shape'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
+
+ Returns:
+ Tensor: The segmentation results, seg_logits from model of each
+ input image.
+ """
+
+ h_stride, w_stride = stride
+ h_crop, w_crop = crop_size
+ batch_size, _, h_img, w_img = inputs.size()
+ out_channels = 1
+ h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
+ w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
+ preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
+ # aux_out1 = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
+ # aux_out2 = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
+ count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
+ crop_imgs = []
+ x1s = []
+ x2s = []
+ y1s = []
+ y2s = []
+ for h_idx in range(h_grids):
+ for w_idx in range(w_grids):
+ y1 = h_idx * h_stride
+ x1 = w_idx * w_stride
+ y2 = min(y1 + h_crop, h_img)
+ x2 = min(x1 + w_crop, w_img)
+ y1 = max(y2 - h_crop, 0)
+ x1 = max(x2 - w_crop, 0)
+ crop_img = inputs[:, :, y1:y2, x1:x2]
+ crop_imgs.append(crop_img)
+ x1s.append(x1)
+ x2s.append(x2)
+ y1s.append(y1)
+ y2s.append(y2)
+ crop_imgs = torch.cat(crop_imgs, dim=0)
+ crop_seg_logits_list = []
+ num_windows = crop_imgs.shape[0]
+ bs = bs
+ length = math.ceil(num_windows / bs)
+ for i in range(length):
+ if i == length - 1:
+ crop_imgs_temp = crop_imgs[bs * i:num_windows, ...]
+ else:
+ crop_imgs_temp = crop_imgs[bs * i:bs * (i + 1), ...]
+
+ crop_seg_logits = self.model.sample(batch_size=crop_imgs_temp.shape[0], cond=crop_imgs_temp, mask=mask)
+ crop_seg_logits_list.append(crop_seg_logits)
+ crop_seg_logits = torch.cat(crop_seg_logits_list, dim=0)
+ for crop_seg_logit, x1, x2, y1, y2 in zip(crop_seg_logits, x1s, x2s, y1s, y2s):
+ preds += F.pad(crop_seg_logit,
+ (int(x1), int(preds.shape[3] - x2), int(y1),
+ int(preds.shape[2] - y2)))
+ count_mat[:, :, y1:y2, x1:x2] += 1
+
+ assert (count_mat == 0).sum() == 0
+ seg_logits = preds / count_mat
+ return seg_logits
diff --git a/src/custom_controlnet_aux/diffusion_edge/requirement.txt b/src/custom_controlnet_aux/diffusion_edge/requirement.txt
new file mode 100644
index 0000000000000000000000000000000000000000..333c5e3dd6f2662a9e887bc015895949eaf8d126
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/requirement.txt
@@ -0,0 +1,9 @@
+#torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
+einops
+scikit-learn
+scipy
+tensorboard
+fvcore
+albumentations
+omegaconf
+numpy==1.23.5
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/__init__.py b/src/custom_controlnet_aux/diffusion_edge/taming/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/data/ade20k.py b/src/custom_controlnet_aux/diffusion_edge/taming/data/ade20k.py
new file mode 100644
index 0000000000000000000000000000000000000000..53586e54a348101f40bb26c4f1e02b884f7f131c
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/data/ade20k.py
@@ -0,0 +1,124 @@
+import os
+import numpy as np
+import cv2
+import albumentations
+from PIL import Image
+from torch.utils.data import Dataset
+
+from custom_controlnet_aux.diffusion_edge.taming.data.sflckr import SegmentationBase # for examples included in repo
+
+
+class Examples(SegmentationBase):
+ def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
+ super().__init__(data_csv="data/ade20k_examples.txt",
+ data_root="data/ade20k_images",
+ segmentation_root="data/ade20k_segmentations",
+ size=size, random_crop=random_crop,
+ interpolation=interpolation,
+ n_labels=151, shift_segmentation=False)
+
+
+# With semantic map and scene label
+class ADE20kBase(Dataset):
+ def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None):
+ self.split = self.get_split()
+ self.n_labels = 151 # unknown + 150
+ self.data_csv = {"train": "data/ade20k_train.txt",
+ "validation": "data/ade20k_test.txt"}[self.split]
+ self.data_root = "data/ade20k_root"
+ with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f:
+ self.scene_categories = f.read().splitlines()
+ self.scene_categories = dict(line.split() for line in self.scene_categories)
+ with open(self.data_csv, "r") as f:
+ self.image_paths = f.read().splitlines()
+ self._length = len(self.image_paths)
+ self.labels = {
+ "relative_file_path_": [l for l in self.image_paths],
+ "file_path_": [os.path.join(self.data_root, "images", l)
+ for l in self.image_paths],
+ "relative_segmentation_path_": [l.replace(".jpg", ".png")
+ for l in self.image_paths],
+ "segmentation_path_": [os.path.join(self.data_root, "annotations",
+ l.replace(".jpg", ".png"))
+ for l in self.image_paths],
+ "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")]
+ for l in self.image_paths],
+ }
+
+ size = None if size is not None and size<=0 else size
+ self.size = size
+ if crop_size is None:
+ self.crop_size = size if size is not None else None
+ else:
+ self.crop_size = crop_size
+ if self.size is not None:
+ self.interpolation = interpolation
+ self.interpolation = {
+ "nearest": cv2.INTER_NEAREST,
+ "bilinear": cv2.INTER_LINEAR,
+ "bicubic": cv2.INTER_CUBIC,
+ "area": cv2.INTER_AREA,
+ "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
+ self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
+ interpolation=self.interpolation)
+ self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
+ interpolation=cv2.INTER_NEAREST)
+
+ if crop_size is not None:
+ self.center_crop = not random_crop
+ if self.center_crop:
+ self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
+ else:
+ self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
+ self.preprocessor = self.cropper
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = dict((k, self.labels[k][i]) for k in self.labels)
+ image = Image.open(example["file_path_"])
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+ if self.size is not None:
+ image = self.image_rescaler(image=image)["image"]
+ segmentation = Image.open(example["segmentation_path_"])
+ segmentation = np.array(segmentation).astype(np.uint8)
+ if self.size is not None:
+ segmentation = self.segmentation_rescaler(image=segmentation)["image"]
+ if self.size is not None:
+ processed = self.preprocessor(image=image, mask=segmentation)
+ else:
+ processed = {"image": image, "mask": segmentation}
+ example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
+ segmentation = processed["mask"]
+ onehot = np.eye(self.n_labels)[segmentation]
+ example["segmentation"] = onehot
+ return example
+
+
+class ADE20kTrain(ADE20kBase):
+ # default to random_crop=True
+ def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None):
+ super().__init__(config=config, size=size, random_crop=random_crop,
+ interpolation=interpolation, crop_size=crop_size)
+
+ def get_split(self):
+ return "train"
+
+
+class ADE20kValidation(ADE20kBase):
+ def get_split(self):
+ return "validation"
+
+
+if __name__ == "__main__":
+ dset = ADE20kValidation()
+ ex = dset[0]
+ for k in ["image", "scene_category", "segmentation"]:
+ print(type(ex[k]))
+ try:
+ print(ex[k].shape)
+ except:
+ print(ex[k])
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/data/annotated_objects_coco.py b/src/custom_controlnet_aux/diffusion_edge/taming/data/annotated_objects_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb27bbfb56a9cb9e490f447e93d0d28fe351721b
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/data/annotated_objects_coco.py
@@ -0,0 +1,139 @@
+import json
+from itertools import chain
+from pathlib import Path
+from typing import Iterable, Dict, List, Callable, Any
+from collections import defaultdict
+
+from tqdm import tqdm
+
+from custom_controlnet_aux.diffusion_edge.taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
+from custom_controlnet_aux.diffusion_edge.taming.data.helper_types import Annotation, ImageDescription, Category
+
+COCO_PATH_STRUCTURE = {
+ 'train': {
+ 'top_level': '',
+ 'instances_annotations': 'annotations/instances_train2017.json',
+ 'stuff_annotations': 'annotations/stuff_train2017.json',
+ 'files': 'train2017'
+ },
+ 'validation': {
+ 'top_level': '',
+ 'instances_annotations': 'annotations/instances_val2017.json',
+ 'stuff_annotations': 'annotations/stuff_val2017.json',
+ 'files': 'val2017'
+ }
+}
+
+
+def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]:
+ return {
+ str(img['id']): ImageDescription(
+ id=img['id'],
+ license=img.get('license'),
+ file_name=img['file_name'],
+ coco_url=img['coco_url'],
+ original_size=(img['width'], img['height']),
+ date_captured=img.get('date_captured'),
+ flickr_url=img.get('flickr_url')
+ )
+ for img in description_json
+ }
+
+
+def load_categories(category_json: Iterable) -> Dict[str, Category]:
+ return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name'])
+ for cat in category_json if cat['name'] != 'other'}
+
+
+def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription],
+ category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]:
+ annotations = defaultdict(list)
+ total = sum(len(a) for a in annotations_json)
+ for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total):
+ image_id = str(ann['image_id'])
+ if image_id not in image_descriptions:
+ raise ValueError(f'image_id [{image_id}] has no image description.')
+ category_id = ann['category_id']
+ try:
+ category_no = category_no_for_id(str(category_id))
+ except KeyError:
+ continue
+
+ width, height = image_descriptions[image_id].original_size
+ bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height)
+
+ annotations[image_id].append(
+ Annotation(
+ id=ann['id'],
+ area=bbox[2]*bbox[3], # use bbox area
+ is_group_of=ann['iscrowd'],
+ image_id=ann['image_id'],
+ bbox=bbox,
+ category_id=str(category_id),
+ category_no=category_no
+ )
+ )
+ return dict(annotations)
+
+
+class AnnotatedObjectsCoco(AnnotatedObjectsDataset):
+ def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs):
+ """
+ @param data_path: is the path to the following folder structure:
+ coco/
+ ├── annotations
+ │ ├── instances_train2017.json
+ │ ├── instances_val2017.json
+ │ ├── stuff_train2017.json
+ │ └── stuff_val2017.json
+ ├── train2017
+ │ ├── 000000000009.jpg
+ │ ├── 000000000025.jpg
+ │ └── ...
+ ├── val2017
+ │ ├── 000000000139.jpg
+ │ ├── 000000000285.jpg
+ │ └── ...
+ @param: split: one of 'train' or 'validation'
+ @param: desired image size (give square images)
+ """
+ super().__init__(**kwargs)
+ self.use_things = use_things
+ self.use_stuff = use_stuff
+
+ with open(self.paths['instances_annotations']) as f:
+ inst_data_json = json.load(f)
+ with open(self.paths['stuff_annotations']) as f:
+ stuff_data_json = json.load(f)
+
+ category_jsons = []
+ annotation_jsons = []
+ if self.use_things:
+ category_jsons.append(inst_data_json['categories'])
+ annotation_jsons.append(inst_data_json['annotations'])
+ if self.use_stuff:
+ category_jsons.append(stuff_data_json['categories'])
+ annotation_jsons.append(stuff_data_json['annotations'])
+
+ self.categories = load_categories(chain(*category_jsons))
+ self.filter_categories()
+ self.setup_category_id_and_number()
+
+ self.image_descriptions = load_image_descriptions(inst_data_json['images'])
+ annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split)
+ self.annotations = self.filter_object_number(annotations, self.min_object_area,
+ self.min_objects_per_image, self.max_objects_per_image)
+ self.image_ids = list(self.annotations.keys())
+ self.clean_up_annotations_and_image_descriptions()
+
+ def get_path_structure(self) -> Dict[str, str]:
+ if self.split not in COCO_PATH_STRUCTURE:
+ raise ValueError(f'Split [{self.split} does not exist for COCO data.]')
+ return COCO_PATH_STRUCTURE[self.split]
+
+ def get_image_path(self, image_id: str) -> Path:
+ return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name)
+
+ def get_image_description(self, image_id: str) -> Dict[str, Any]:
+ # noinspection PyProtectedMember
+ return self.image_descriptions[image_id]._asdict()
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/data/annotated_objects_dataset.py b/src/custom_controlnet_aux/diffusion_edge/taming/data/annotated_objects_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..af3af83a56abafa1c91c610b34c4ff9e333bcfd4
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/data/annotated_objects_dataset.py
@@ -0,0 +1,218 @@
+from pathlib import Path
+from typing import Optional, List, Callable, Dict, Any, Union
+import warnings
+
+import PIL.Image as pil_image
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+from custom_controlnet_aux.diffusion_edge.taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder
+from custom_controlnet_aux.diffusion_edge.taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
+from custom_controlnet_aux.diffusion_edge.taming.data.conditional_builder.utils import load_object_from_string
+from custom_controlnet_aux.diffusion_edge.taming.data.helper_types import BoundingBox, CropMethodType, Image, Annotation, SplitType
+from custom_controlnet_aux.diffusion_edge.taming.data.image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \
+ Random2dCropReturnCoordinates, RandomHorizontalFlipReturn, convert_pil_to_tensor
+
+
+class AnnotatedObjectsDataset(Dataset):
+ def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str], target_image_size: int,
+ min_object_area: float, min_objects_per_image: int, max_objects_per_image: int,
+ crop_method: CropMethodType, random_flip: bool, no_tokens: int, use_group_parameter: bool,
+ encode_crop: bool, category_allow_list_target: str = "", category_mapping_target: str = "",
+ no_object_classes: Optional[int] = None):
+ self.data_path = data_path
+ self.split = split
+ self.keys = keys
+ self.target_image_size = target_image_size
+ self.min_object_area = min_object_area
+ self.min_objects_per_image = min_objects_per_image
+ self.max_objects_per_image = max_objects_per_image
+ self.crop_method = crop_method
+ self.random_flip = random_flip
+ self.no_tokens = no_tokens
+ self.use_group_parameter = use_group_parameter
+ self.encode_crop = encode_crop
+
+ self.annotations = None
+ self.image_descriptions = None
+ self.categories = None
+ self.category_ids = None
+ self.category_number = None
+ self.image_ids = None
+ self.transform_functions: List[Callable] = self.setup_transform(target_image_size, crop_method, random_flip)
+ self.paths = self.build_paths(self.data_path)
+ self._conditional_builders = None
+ self.category_allow_list = None
+ if category_allow_list_target:
+ allow_list = load_object_from_string(category_allow_list_target)
+ self.category_allow_list = {name for name, _ in allow_list}
+ self.category_mapping = {}
+ if category_mapping_target:
+ self.category_mapping = load_object_from_string(category_mapping_target)
+ self.no_object_classes = no_object_classes
+
+ def build_paths(self, top_level: Union[str, Path]) -> Dict[str, Path]:
+ top_level = Path(top_level)
+ sub_paths = {name: top_level.joinpath(sub_path) for name, sub_path in self.get_path_structure().items()}
+ for path in sub_paths.values():
+ if not path.exists():
+ raise FileNotFoundError(f'{type(self).__name__} data structure error: [{path}] does not exist.')
+ return sub_paths
+
+ @staticmethod
+ def load_image_from_disk(path: Path) -> Image:
+ return pil_image.open(path).convert('RGB')
+
+ @staticmethod
+ def setup_transform(target_image_size: int, crop_method: CropMethodType, random_flip: bool):
+ transform_functions = []
+ if crop_method == 'none':
+ transform_functions.append(transforms.Resize((target_image_size, target_image_size)))
+ elif crop_method == 'center':
+ transform_functions.extend([
+ transforms.Resize(target_image_size),
+ CenterCropReturnCoordinates(target_image_size)
+ ])
+ elif crop_method == 'random-1d':
+ transform_functions.extend([
+ transforms.Resize(target_image_size),
+ RandomCrop1dReturnCoordinates(target_image_size)
+ ])
+ elif crop_method == 'random-2d':
+ transform_functions.extend([
+ Random2dCropReturnCoordinates(target_image_size),
+ transforms.Resize(target_image_size)
+ ])
+ elif crop_method is None:
+ return None
+ else:
+ raise ValueError(f'Received invalid crop method [{crop_method}].')
+ if random_flip:
+ transform_functions.append(RandomHorizontalFlipReturn())
+ transform_functions.append(transforms.Lambda(lambda x: x / 127.5 - 1.))
+ return transform_functions
+
+ def image_transform(self, x: Tensor) -> (Optional[BoundingBox], Optional[bool], Tensor):
+ crop_bbox = None
+ flipped = None
+ for t in self.transform_functions:
+ if isinstance(t, (RandomCrop1dReturnCoordinates, CenterCropReturnCoordinates, Random2dCropReturnCoordinates)):
+ crop_bbox, x = t(x)
+ elif isinstance(t, RandomHorizontalFlipReturn):
+ flipped, x = t(x)
+ else:
+ x = t(x)
+ return crop_bbox, flipped, x
+
+ @property
+ def no_classes(self) -> int:
+ return self.no_object_classes if self.no_object_classes else len(self.categories)
+
+ @property
+ def conditional_builders(self) -> ObjectsCenterPointsConditionalBuilder:
+ # cannot set this up in init because no_classes is only known after loading data in init of superclass
+ if self._conditional_builders is None:
+ self._conditional_builders = {
+ 'objects_center_points': ObjectsCenterPointsConditionalBuilder(
+ self.no_classes,
+ self.max_objects_per_image,
+ self.no_tokens,
+ self.encode_crop,
+ self.use_group_parameter,
+ getattr(self, 'use_additional_parameters', False)
+ ),
+ 'objects_bbox': ObjectsBoundingBoxConditionalBuilder(
+ self.no_classes,
+ self.max_objects_per_image,
+ self.no_tokens,
+ self.encode_crop,
+ self.use_group_parameter,
+ getattr(self, 'use_additional_parameters', False)
+ )
+ }
+ return self._conditional_builders
+
+ def filter_categories(self) -> None:
+ if self.category_allow_list:
+ self.categories = {id_: cat for id_, cat in self.categories.items() if cat.name in self.category_allow_list}
+ if self.category_mapping:
+ self.categories = {id_: cat for id_, cat in self.categories.items() if cat.id not in self.category_mapping}
+
+ def setup_category_id_and_number(self) -> None:
+ self.category_ids = list(self.categories.keys())
+ self.category_ids.sort()
+ if '/m/01s55n' in self.category_ids:
+ self.category_ids.remove('/m/01s55n')
+ self.category_ids.append('/m/01s55n')
+ self.category_number = {category_id: i for i, category_id in enumerate(self.category_ids)}
+ if self.category_allow_list is not None and self.category_mapping is None \
+ and len(self.category_ids) != len(self.category_allow_list):
+ warnings.warn('Unexpected number of categories: Mismatch with category_allow_list. '
+ 'Make sure all names in category_allow_list exist.')
+
+ def clean_up_annotations_and_image_descriptions(self) -> None:
+ image_id_set = set(self.image_ids)
+ self.annotations = {k: v for k, v in self.annotations.items() if k in image_id_set}
+ self.image_descriptions = {k: v for k, v in self.image_descriptions.items() if k in image_id_set}
+
+ @staticmethod
+ def filter_object_number(all_annotations: Dict[str, List[Annotation]], min_object_area: float,
+ min_objects_per_image: int, max_objects_per_image: int) -> Dict[str, List[Annotation]]:
+ filtered = {}
+ for image_id, annotations in all_annotations.items():
+ annotations_with_min_area = [a for a in annotations if a.area > min_object_area]
+ if min_objects_per_image <= len(annotations_with_min_area) <= max_objects_per_image:
+ filtered[image_id] = annotations_with_min_area
+ return filtered
+
+ def __len__(self):
+ return len(self.image_ids)
+
+ def __getitem__(self, n: int) -> Dict[str, Any]:
+ image_id = self.get_image_id(n)
+ sample = self.get_image_description(image_id)
+ sample['annotations'] = self.get_annotation(image_id)
+
+ if 'image' in self.keys:
+ sample['image_path'] = str(self.get_image_path(image_id))
+ sample['image'] = self.load_image_from_disk(sample['image_path'])
+ sample['image'] = convert_pil_to_tensor(sample['image'])
+ sample['crop_bbox'], sample['flipped'], sample['image'] = self.image_transform(sample['image'])
+ sample['image'] = sample['image'].permute(1, 2, 0)
+
+ for conditional, builder in self.conditional_builders.items():
+ if conditional in self.keys:
+ sample[conditional] = builder.build(sample['annotations'], sample['crop_bbox'], sample['flipped'])
+
+ if self.keys:
+ # only return specified keys
+ sample = {key: sample[key] for key in self.keys}
+ return sample
+
+ def get_image_id(self, no: int) -> str:
+ return self.image_ids[no]
+
+ def get_annotation(self, image_id: str) -> str:
+ return self.annotations[image_id]
+
+ def get_textual_label_for_category_id(self, category_id: str) -> str:
+ return self.categories[category_id].name
+
+ def get_textual_label_for_category_no(self, category_no: int) -> str:
+ return self.categories[self.get_category_id(category_no)].name
+
+ def get_category_number(self, category_id: str) -> int:
+ return self.category_number[category_id]
+
+ def get_category_id(self, category_no: int) -> str:
+ return self.category_ids[category_no]
+
+ def get_image_description(self, image_id: str) -> Dict[str, Any]:
+ raise NotImplementedError()
+
+ def get_path_structure(self):
+ raise NotImplementedError
+
+ def get_image_path(self, image_id: str) -> Path:
+ raise NotImplementedError
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/data/annotated_objects_open_images.py b/src/custom_controlnet_aux/diffusion_edge/taming/data/annotated_objects_open_images.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc5c8185496ec1be6f16388dee8217f4cb686f3a
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/data/annotated_objects_open_images.py
@@ -0,0 +1,137 @@
+from collections import defaultdict
+from csv import DictReader, reader as TupleReader
+from pathlib import Path
+from typing import Dict, List, Any
+import warnings
+
+from custom_controlnet_aux.diffusion_edge.taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
+from custom_controlnet_aux.diffusion_edge.taming.data.helper_types import Annotation, Category
+from tqdm import tqdm
+
+OPEN_IMAGES_STRUCTURE = {
+ 'train': {
+ 'top_level': '',
+ 'class_descriptions': 'class-descriptions-boxable.csv',
+ 'annotations': 'oidv6-train-annotations-bbox.csv',
+ 'file_list': 'train-images-boxable.csv',
+ 'files': 'train'
+ },
+ 'validation': {
+ 'top_level': '',
+ 'class_descriptions': 'class-descriptions-boxable.csv',
+ 'annotations': 'validation-annotations-bbox.csv',
+ 'file_list': 'validation-images.csv',
+ 'files': 'validation'
+ },
+ 'test': {
+ 'top_level': '',
+ 'class_descriptions': 'class-descriptions-boxable.csv',
+ 'annotations': 'test-annotations-bbox.csv',
+ 'file_list': 'test-images.csv',
+ 'files': 'test'
+ }
+}
+
+
+def load_annotations(descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str],
+ category_no_for_id: Dict[str, int]) -> Dict[str, List[Annotation]]:
+ annotations: Dict[str, List[Annotation]] = defaultdict(list)
+ with open(descriptor_path) as file:
+ reader = DictReader(file)
+ for i, row in tqdm(enumerate(reader), total=14620000, desc='Loading OpenImages annotations'):
+ width = float(row['XMax']) - float(row['XMin'])
+ height = float(row['YMax']) - float(row['YMin'])
+ area = width * height
+ category_id = row['LabelName']
+ if category_id in category_mapping:
+ category_id = category_mapping[category_id]
+ if area >= min_object_area and category_id in category_no_for_id:
+ annotations[row['ImageID']].append(
+ Annotation(
+ id=i,
+ image_id=row['ImageID'],
+ source=row['Source'],
+ category_id=category_id,
+ category_no=category_no_for_id[category_id],
+ confidence=float(row['Confidence']),
+ bbox=(float(row['XMin']), float(row['YMin']), width, height),
+ area=area,
+ is_occluded=bool(int(row['IsOccluded'])),
+ is_truncated=bool(int(row['IsTruncated'])),
+ is_group_of=bool(int(row['IsGroupOf'])),
+ is_depiction=bool(int(row['IsDepiction'])),
+ is_inside=bool(int(row['IsInside']))
+ )
+ )
+ if 'train' in str(descriptor_path) and i < 14000000:
+ warnings.warn(f'Running with subset of Open Images. Train dataset has length [{len(annotations)}].')
+ return dict(annotations)
+
+
+def load_image_ids(csv_path: Path) -> List[str]:
+ with open(csv_path) as file:
+ reader = DictReader(file)
+ return [row['image_name'] for row in reader]
+
+
+def load_categories(csv_path: Path) -> Dict[str, Category]:
+ with open(csv_path) as file:
+ reader = TupleReader(file)
+ return {row[0]: Category(id=row[0], name=row[1], super_category=None) for row in reader}
+
+
+class AnnotatedObjectsOpenImages(AnnotatedObjectsDataset):
+ def __init__(self, use_additional_parameters: bool, **kwargs):
+ """
+ @param data_path: is the path to the following folder structure:
+ open_images/
+ │ oidv6-train-annotations-bbox.csv
+ ├── class-descriptions-boxable.csv
+ ├── oidv6-train-annotations-bbox.csv
+ ├── test
+ │ ├── 000026e7ee790996.jpg
+ │ ├── 000062a39995e348.jpg
+ │ └── ...
+ ├── test-annotations-bbox.csv
+ ├── test-images.csv
+ ├── train
+ │ ├── 000002b66c9c498e.jpg
+ │ ├── 000002b97e5471a0.jpg
+ │ └── ...
+ ├── train-images-boxable.csv
+ ├── validation
+ │ ├── 0001eeaf4aed83f9.jpg
+ │ ├── 0004886b7d043cfd.jpg
+ │ └── ...
+ ├── validation-annotations-bbox.csv
+ └── validation-images.csv
+ @param: split: one of 'train', 'validation' or 'test'
+ @param: desired image size (returns square images)
+ """
+
+ super().__init__(**kwargs)
+ self.use_additional_parameters = use_additional_parameters
+
+ self.categories = load_categories(self.paths['class_descriptions'])
+ self.filter_categories()
+ self.setup_category_id_and_number()
+
+ self.image_descriptions = {}
+ annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_mapping,
+ self.category_number)
+ self.annotations = self.filter_object_number(annotations, self.min_object_area, self.min_objects_per_image,
+ self.max_objects_per_image)
+ self.image_ids = list(self.annotations.keys())
+ self.clean_up_annotations_and_image_descriptions()
+
+ def get_path_structure(self) -> Dict[str, str]:
+ if self.split not in OPEN_IMAGES_STRUCTURE:
+ raise ValueError(f'Split [{self.split} does not exist for Open Images data.]')
+ return OPEN_IMAGES_STRUCTURE[self.split]
+
+ def get_image_path(self, image_id: str) -> Path:
+ return self.paths['files'].joinpath(f'{image_id:0>16}.jpg')
+
+ def get_image_description(self, image_id: str) -> Dict[str, Any]:
+ image_path = self.get_image_path(image_id)
+ return {'file_path': str(image_path), 'file_name': image_path.name}
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/data/base.py b/src/custom_controlnet_aux/diffusion_edge/taming/data/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..e21667df4ce4baa6bb6aad9f8679bd756e2ffdb7
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/data/base.py
@@ -0,0 +1,70 @@
+import bisect
+import numpy as np
+import albumentations
+from PIL import Image
+from torch.utils.data import Dataset, ConcatDataset
+
+
+class ConcatDatasetWithIndex(ConcatDataset):
+ """Modified from original pytorch code to return dataset idx"""
+ def __getitem__(self, idx):
+ if idx < 0:
+ if -idx > len(self):
+ raise ValueError("absolute value of index should not exceed dataset length")
+ idx = len(self) + idx
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+ if dataset_idx == 0:
+ sample_idx = idx
+ else:
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+ return self.datasets[dataset_idx][sample_idx], dataset_idx
+
+
+class ImagePaths(Dataset):
+ def __init__(self, paths, size=None, random_crop=False, labels=None):
+ self.size = size
+ self.random_crop = random_crop
+
+ self.labels = dict() if labels is None else labels
+ self.labels["file_path_"] = paths
+ self._length = len(paths)
+
+ if self.size is not None and self.size > 0:
+ self.rescaler = albumentations.SmallestMaxSize(max_size = self.size)
+ if not self.random_crop:
+ self.cropper = albumentations.CenterCrop(height=self.size,width=self.size)
+ else:
+ self.cropper = albumentations.RandomCrop(height=self.size,width=self.size)
+ self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
+ else:
+ self.preprocessor = lambda **kwargs: kwargs
+
+ def __len__(self):
+ return self._length
+
+ def preprocess_image(self, image_path):
+ image = Image.open(image_path)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+ image = self.preprocessor(image=image)["image"]
+ image = (image/127.5 - 1.0).astype(np.float32)
+ return image
+
+ def __getitem__(self, i):
+ example = dict()
+ example["image"] = self.preprocess_image(self.labels["file_path_"][i])
+ for k in self.labels:
+ example[k] = self.labels[k][i]
+ return example
+
+
+class NumpyPaths(ImagePaths):
+ def preprocess_image(self, image_path):
+ image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024
+ image = np.transpose(image, (1,2,0))
+ image = Image.fromarray(image, mode="RGB")
+ image = np.array(image).astype(np.uint8)
+ image = self.preprocessor(image=image)["image"]
+ image = (image/127.5 - 1.0).astype(np.float32)
+ return image
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/data/coco.py b/src/custom_controlnet_aux/diffusion_edge/taming/data/coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..23e2da222563c7f2085a2e40cd47699b049e905a
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/data/coco.py
@@ -0,0 +1,176 @@
+import os
+import json
+import albumentations
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+from torch.utils.data import Dataset
+
+from custom_controlnet_aux.diffusion_edge.taming.data.sflckr import SegmentationBase # for examples included in repo
+
+
+class Examples(SegmentationBase):
+ def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
+ super().__init__(data_csv="data/coco_examples.txt",
+ data_root="data/coco_images",
+ segmentation_root="data/coco_segmentations",
+ size=size, random_crop=random_crop,
+ interpolation=interpolation,
+ n_labels=183, shift_segmentation=True)
+
+
+class CocoBase(Dataset):
+ """needed for (image, caption, segmentation) pairs"""
+ def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False,
+ crop_size=None, force_no_crop=False, given_files=None):
+ self.split = self.get_split()
+ self.size = size
+ if crop_size is None:
+ self.crop_size = size
+ else:
+ self.crop_size = crop_size
+
+ self.onehot = onehot_segmentation # return segmentation as rgb or one hot
+ self.stuffthing = use_stuffthing # include thing in segmentation
+ if self.onehot and not self.stuffthing:
+ raise NotImplemented("One hot mode is only supported for the "
+ "stuffthings version because labels are stored "
+ "a bit different.")
+
+ data_json = datajson
+ with open(data_json) as json_file:
+ self.json_data = json.load(json_file)
+ self.img_id_to_captions = dict()
+ self.img_id_to_filepath = dict()
+ self.img_id_to_segmentation_filepath = dict()
+
+ assert data_json.split("/")[-1] in ["captions_train2017.json",
+ "captions_val2017.json"]
+ if self.stuffthing:
+ self.segmentation_prefix = (
+ "data/cocostuffthings/val2017" if
+ data_json.endswith("captions_val2017.json") else
+ "data/cocostuffthings/train2017")
+ else:
+ self.segmentation_prefix = (
+ "data/coco/annotations/stuff_val2017_pixelmaps" if
+ data_json.endswith("captions_val2017.json") else
+ "data/coco/annotations/stuff_train2017_pixelmaps")
+
+ imagedirs = self.json_data["images"]
+ self.labels = {"image_ids": list()}
+ for imgdir in tqdm(imagedirs, desc="ImgToPath"):
+ self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"])
+ self.img_id_to_captions[imgdir["id"]] = list()
+ pngfilename = imgdir["file_name"].replace("jpg", "png")
+ self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join(
+ self.segmentation_prefix, pngfilename)
+ if given_files is not None:
+ if pngfilename in given_files:
+ self.labels["image_ids"].append(imgdir["id"])
+ else:
+ self.labels["image_ids"].append(imgdir["id"])
+
+ capdirs = self.json_data["annotations"]
+ for capdir in tqdm(capdirs, desc="ImgToCaptions"):
+ # there are in average 5 captions per image
+ self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]]))
+
+ self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
+ if self.split=="validation":
+ self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
+ else:
+ self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
+ self.preprocessor = albumentations.Compose(
+ [self.rescaler, self.cropper],
+ additional_targets={"segmentation": "image"})
+ if force_no_crop:
+ self.rescaler = albumentations.Resize(height=self.size, width=self.size)
+ self.preprocessor = albumentations.Compose(
+ [self.rescaler],
+ additional_targets={"segmentation": "image"})
+
+ def __len__(self):
+ return len(self.labels["image_ids"])
+
+ def preprocess_image(self, image_path, segmentation_path):
+ image = Image.open(image_path)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+
+ segmentation = Image.open(segmentation_path)
+ if not self.onehot and not segmentation.mode == "RGB":
+ segmentation = segmentation.convert("RGB")
+ segmentation = np.array(segmentation).astype(np.uint8)
+ if self.onehot:
+ assert self.stuffthing
+ # stored in caffe format: unlabeled==255. stuff and thing from
+ # 0-181. to be compatible with the labels in
+ # https://github.com/nightrome/cocostuff/blob/master/labels.txt
+ # we shift stuffthing one to the right and put unlabeled in zero
+ # as long as segmentation is uint8 shifting to right handles the
+ # latter too
+ assert segmentation.dtype == np.uint8
+ segmentation = segmentation + 1
+
+ processed = self.preprocessor(image=image, segmentation=segmentation)
+ image, segmentation = processed["image"], processed["segmentation"]
+ image = (image / 127.5 - 1.0).astype(np.float32)
+
+ if self.onehot:
+ assert segmentation.dtype == np.uint8
+ # make it one hot
+ n_labels = 183
+ flatseg = np.ravel(segmentation)
+ onehot = np.zeros((flatseg.size, n_labels), dtype=np.bool)
+ onehot[np.arange(flatseg.size), flatseg] = True
+ onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
+ segmentation = onehot
+ else:
+ segmentation = (segmentation / 127.5 - 1.0).astype(np.float32)
+ return image, segmentation
+
+ def __getitem__(self, i):
+ img_path = self.img_id_to_filepath[self.labels["image_ids"][i]]
+ seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]]
+ image, segmentation = self.preprocess_image(img_path, seg_path)
+ captions = self.img_id_to_captions[self.labels["image_ids"][i]]
+ # randomly draw one of all available captions per image
+ caption = captions[np.random.randint(0, len(captions))]
+ example = {"image": image,
+ "caption": [str(caption[0])],
+ "segmentation": segmentation,
+ "img_path": img_path,
+ "seg_path": seg_path,
+ "filename_": img_path.split(os.sep)[-1]
+ }
+ return example
+
+
+class CocoImagesAndCaptionsTrain(CocoBase):
+ """returns a pair of (image, caption)"""
+ def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False):
+ super().__init__(size=size,
+ dataroot="data/coco/train2017",
+ datajson="data/coco/annotations/captions_train2017.json",
+ onehot_segmentation=onehot_segmentation,
+ use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop)
+
+ def get_split(self):
+ return "train"
+
+
+class CocoImagesAndCaptionsValidation(CocoBase):
+ """returns a pair of (image, caption)"""
+ def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
+ given_files=None):
+ super().__init__(size=size,
+ dataroot="data/coco/val2017",
+ datajson="data/coco/annotations/captions_val2017.json",
+ onehot_segmentation=onehot_segmentation,
+ use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
+ given_files=given_files)
+
+ def get_split(self):
+ return "validation"
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/data/conditional_builder/objects_bbox.py b/src/custom_controlnet_aux/diffusion_edge/taming/data/conditional_builder/objects_bbox.py
new file mode 100644
index 0000000000000000000000000000000000000000..44fa4867952d60cdd84060dd186b126cf7eefbe2
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/data/conditional_builder/objects_bbox.py
@@ -0,0 +1,60 @@
+from itertools import cycle
+from typing import List, Tuple, Callable, Optional
+
+from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
+from more_itertools.recipes import grouper
+from custom_controlnet_aux.diffusion_edge.taming.data.image_transforms import convert_pil_to_tensor
+from torch import LongTensor, Tensor
+
+from custom_controlnet_aux.diffusion_edge.taming.data.helper_types import BoundingBox, Annotation
+from custom_controlnet_aux.diffusion_edge.taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
+from custom_controlnet_aux.diffusion_edge.taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \
+ pad_list, get_plot_font_size, absolute_bbox
+
+
+class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder):
+ @property
+ def object_descriptor_length(self) -> int:
+ return 3
+
+ def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
+ object_triples = [
+ (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox))
+ for ann in annotations
+ ]
+ empty_triple = (self.none, self.none, self.none)
+ object_triples = pad_list(object_triples, empty_triple, self.no_max_objects)
+ return object_triples
+
+ def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]:
+ conditional_list = conditional.tolist()
+ crop_coordinates = None
+ if self.encode_crop:
+ crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
+ conditional_list = conditional_list[:-2]
+ object_triples = grouper(conditional_list, 3)
+ assert conditional.shape[0] == self.embedding_dim
+ return [
+ (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2]))
+ for object_triple in object_triples if object_triple[0] != self.none
+ ], crop_coordinates
+
+ def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
+ line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
+ plot = pil_image.new('RGB', figure_size, WHITE)
+ draw = pil_img_draw.Draw(plot)
+ font = ImageFont.truetype(
+ "/usr/share/fonts/truetype/lato/Lato-Regular.ttf",
+ size=get_plot_font_size(font_size, figure_size)
+ )
+ width, height = plot.size
+ description, crop_coordinates = self.inverse_build(conditional)
+ for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)):
+ annotation = self.representation_to_annotation(representation)
+ class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation)
+ bbox = absolute_bbox(bbox, width, height)
+ draw.rectangle(bbox, outline=color, width=line_width)
+ draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font)
+ if crop_coordinates is not None:
+ draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
+ return convert_pil_to_tensor(plot) / 127.5 - 1.
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/data/conditional_builder/objects_center_points.py b/src/custom_controlnet_aux/diffusion_edge/taming/data/conditional_builder/objects_center_points.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fd1989c6647f08384d69b3cad7fe57288776d0b
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/data/conditional_builder/objects_center_points.py
@@ -0,0 +1,168 @@
+import math
+import random
+import warnings
+from itertools import cycle
+from typing import List, Optional, Tuple, Callable
+
+from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
+from more_itertools.recipes import grouper
+from custom_controlnet_aux.diffusion_edge.taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \
+ additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \
+ absolute_bbox, rescale_annotations
+from custom_controlnet_aux.diffusion_edge.taming.data.helper_types import BoundingBox, Annotation
+from custom_controlnet_aux.diffusion_edge.taming.data.image_transforms import convert_pil_to_tensor
+from torch import LongTensor, Tensor
+
+
+class ObjectsCenterPointsConditionalBuilder:
+ def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, encode_crop: bool,
+ use_group_parameter: bool, use_additional_parameters: bool):
+ self.no_object_classes = no_object_classes
+ self.no_max_objects = no_max_objects
+ self.no_tokens = no_tokens
+ self.encode_crop = encode_crop
+ self.no_sections = int(math.sqrt(self.no_tokens))
+ self.use_group_parameter = use_group_parameter
+ self.use_additional_parameters = use_additional_parameters
+
+ @property
+ def none(self) -> int:
+ return self.no_tokens - 1
+
+ @property
+ def object_descriptor_length(self) -> int:
+ return 2
+
+ @property
+ def embedding_dim(self) -> int:
+ extra_length = 2 if self.encode_crop else 0
+ return self.no_max_objects * self.object_descriptor_length + extra_length
+
+ def tokenize_coordinates(self, x: float, y: float) -> int:
+ """
+ Express 2d coordinates with one number.
+ Example: assume self.no_tokens = 16, then no_sections = 4:
+ 0 0 0 0
+ 0 0 # 0
+ 0 0 0 0
+ 0 0 0 x
+ Then the # position corresponds to token 6, the x position to token 15.
+ @param x: float in [0, 1]
+ @param y: float in [0, 1]
+ @return: discrete tokenized coordinate
+ """
+ x_discrete = int(round(x * (self.no_sections - 1)))
+ y_discrete = int(round(y * (self.no_sections - 1)))
+ return y_discrete * self.no_sections + x_discrete
+
+ def coordinates_from_token(self, token: int) -> (float, float):
+ x = token % self.no_sections
+ y = token // self.no_sections
+ return x / (self.no_sections - 1), y / (self.no_sections - 1)
+
+ def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox:
+ x0, y0 = self.coordinates_from_token(token1)
+ x1, y1 = self.coordinates_from_token(token2)
+ return x0, y0, x1 - x0, y1 - y0
+
+ def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple[int, int]:
+ return self.tokenize_coordinates(bbox[0], bbox[1]), \
+ self.tokenize_coordinates(bbox[0] + bbox[2], bbox[1] + bbox[3])
+
+ def inverse_build(self, conditional: LongTensor) \
+ -> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]:
+ conditional_list = conditional.tolist()
+ crop_coordinates = None
+ if self.encode_crop:
+ crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
+ conditional_list = conditional_list[:-2]
+ table_of_content = grouper(conditional_list, self.object_descriptor_length)
+ assert conditional.shape[0] == self.embedding_dim
+ return [
+ (object_tuple[0], self.coordinates_from_token(object_tuple[1]))
+ for object_tuple in table_of_content if object_tuple[0] != self.none
+ ], crop_coordinates
+
+ def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
+ line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
+ plot = pil_image.new('RGB', figure_size, WHITE)
+ draw = pil_img_draw.Draw(plot)
+ circle_size = get_circle_size(figure_size)
+ font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf',
+ size=get_plot_font_size(font_size, figure_size))
+ width, height = plot.size
+ description, crop_coordinates = self.inverse_build(conditional)
+ for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)):
+ x_abs, y_abs = x * width, y * height
+ ann = self.representation_to_annotation(representation)
+ label = label_for_category_no(ann.category_no) + ' ' + additional_parameters_string(ann)
+ ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size]
+ draw.ellipse(ellipse_bbox, fill=color, width=0)
+ draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font)
+ if crop_coordinates is not None:
+ draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
+ return convert_pil_to_tensor(plot) / 127.5 - 1.
+
+ def object_representation(self, annotation: Annotation) -> int:
+ modifier = 0
+ if self.use_group_parameter:
+ modifier |= 1 * (annotation.is_group_of is True)
+ if self.use_additional_parameters:
+ modifier |= 2 * (annotation.is_occluded is True)
+ modifier |= 4 * (annotation.is_depiction is True)
+ modifier |= 8 * (annotation.is_inside is True)
+ return annotation.category_no + self.no_object_classes * modifier
+
+ def representation_to_annotation(self, representation: int) -> Annotation:
+ category_no = representation % self.no_object_classes
+ modifier = representation // self.no_object_classes
+ # noinspection PyTypeChecker
+ return Annotation(
+ area=None, image_id=None, bbox=None, category_id=None, id=None, source=None, confidence=None,
+ category_no=category_no,
+ is_group_of=bool((modifier & 1) * self.use_group_parameter),
+ is_occluded=bool((modifier & 2) * self.use_additional_parameters),
+ is_depiction=bool((modifier & 4) * self.use_additional_parameters),
+ is_inside=bool((modifier & 8) * self.use_additional_parameters)
+ )
+
+ def _crop_encoder(self, crop_coordinates: BoundingBox) -> List[int]:
+ return list(self.token_pair_from_bbox(crop_coordinates))
+
+ def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
+ object_tuples = [
+ (self.object_representation(a),
+ self.tokenize_coordinates(a.bbox[0] + a.bbox[2] / 2, a.bbox[1] + a.bbox[3] / 2))
+ for a in annotations
+ ]
+ empty_tuple = (self.none, self.none)
+ object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects)
+ return object_tuples
+
+ def build(self, annotations: List, crop_coordinates: Optional[BoundingBox] = None, horizontal_flip: bool = False) \
+ -> LongTensor:
+ if len(annotations) == 0:
+ warnings.warn('Did not receive any annotations.')
+ if len(annotations) > self.no_max_objects:
+ warnings.warn('Received more annotations than allowed.')
+ annotations = annotations[:self.no_max_objects]
+
+ if not crop_coordinates:
+ crop_coordinates = FULL_CROP
+
+ random.shuffle(annotations)
+ annotations = filter_annotations(annotations, crop_coordinates)
+ if self.encode_crop:
+ annotations = rescale_annotations(annotations, FULL_CROP, horizontal_flip)
+ if horizontal_flip:
+ crop_coordinates = horizontally_flip_bbox(crop_coordinates)
+ extra = self._crop_encoder(crop_coordinates)
+ else:
+ annotations = rescale_annotations(annotations, crop_coordinates, horizontal_flip)
+ extra = []
+
+ object_tuples = self._make_object_descriptors(annotations)
+ flattened = [token for tuple_ in object_tuples for token in tuple_] + extra
+ assert len(flattened) == self.embedding_dim
+ assert all(0 <= value < self.no_tokens for value in flattened)
+ return LongTensor(flattened)
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/data/conditional_builder/utils.py b/src/custom_controlnet_aux/diffusion_edge/taming/data/conditional_builder/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b97140e32b3d0772247ddb7121c22ce56cc90c4f
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/data/conditional_builder/utils.py
@@ -0,0 +1,105 @@
+import importlib
+from typing import List, Any, Tuple, Optional
+
+from custom_controlnet_aux.diffusion_edge.taming.data.helper_types import BoundingBox, Annotation
+
+# source: seaborn, color palette tab10
+COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188),
+ (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)]
+BLACK = (0, 0, 0)
+GRAY_75 = (63, 63, 63)
+GRAY_50 = (127, 127, 127)
+GRAY_25 = (191, 191, 191)
+WHITE = (255, 255, 255)
+FULL_CROP = (0., 0., 1., 1.)
+
+
+def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float:
+ """
+ Give intersection area of two rectangles.
+ @param rectangle1: (x0, y0, w, h) of first rectangle
+ @param rectangle2: (x0, y0, w, h) of second rectangle
+ """
+ rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3]
+ rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3]
+ x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0]))
+ y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1]))
+ return x_overlap * y_overlap
+
+
+def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox:
+ return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3]
+
+
+def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]:
+ bbox = relative_bbox
+ bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height
+ return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
+
+
+def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List:
+ return list_ + [pad_element for _ in range(pad_to_length - len(list_))]
+
+
+def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \
+ List[Annotation]:
+ def clamp(x: float):
+ return max(min(x, 1.), 0.)
+
+ def rescale_bbox(bbox: BoundingBox) -> BoundingBox:
+ x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
+ y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
+ w = min(bbox[2] / crop_coordinates[2], 1 - x0)
+ h = min(bbox[3] / crop_coordinates[3], 1 - y0)
+ if flip:
+ x0 = 1 - (x0 + w)
+ return x0, y0, w, h
+
+ return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations]
+
+
+def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List:
+ return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0]
+
+
+def additional_parameters_string(annotation: Annotation, short: bool = True) -> str:
+ sl = slice(1) if short else slice(None)
+ string = ''
+ if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside):
+ return string
+ if annotation.is_group_of:
+ string += 'group'[sl] + ','
+ if annotation.is_occluded:
+ string += 'occluded'[sl] + ','
+ if annotation.is_depiction:
+ string += 'depiction'[sl] + ','
+ if annotation.is_inside:
+ string += 'inside'[sl]
+ return '(' + string.strip(",") + ')'
+
+
+def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int:
+ if font_size is None:
+ font_size = 10
+ if max(figure_size) >= 256:
+ font_size = 12
+ if max(figure_size) >= 512:
+ font_size = 15
+ return font_size
+
+
+def get_circle_size(figure_size: Tuple[int, int]) -> int:
+ circle_size = 2
+ if max(figure_size) >= 256:
+ circle_size = 3
+ if max(figure_size) >= 512:
+ circle_size = 4
+ return circle_size
+
+
+def load_object_from_string(object_string: str) -> Any:
+ """
+ Source: https://stackoverflow.com/a/10773699
+ """
+ module_name, class_name = object_string.rsplit(".", 1)
+ return getattr(importlib.import_module(module_name), class_name)
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/data/custom.py b/src/custom_controlnet_aux/diffusion_edge/taming/data/custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9224014b80840694acfa5efe8b0d23274fb09d1
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/data/custom.py
@@ -0,0 +1,38 @@
+import os
+import numpy as np
+import albumentations
+from torch.utils.data import Dataset
+
+from custom_controlnet_aux.diffusion_edge.taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
+
+
+class CustomBase(Dataset):
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ self.data = None
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ example = self.data[i]
+ return example
+
+
+
+class CustomTrain(CustomBase):
+ def __init__(self, size, training_images_list_file):
+ super().__init__()
+ with open(training_images_list_file, "r") as f:
+ paths = f.read().splitlines()
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
+
+
+class CustomTest(CustomBase):
+ def __init__(self, size, test_images_list_file):
+ super().__init__()
+ with open(test_images_list_file, "r") as f:
+ paths = f.read().splitlines()
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
+
+
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/data/faceshq.py b/src/custom_controlnet_aux/diffusion_edge/taming/data/faceshq.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d292b2198fe63745e271adc8d4f2720b393b668
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/data/faceshq.py
@@ -0,0 +1,134 @@
+import os
+import numpy as np
+import albumentations
+from torch.utils.data import Dataset
+
+from custom_controlnet_aux.diffusion_edge.taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
+
+
+class FacesBase(Dataset):
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ self.data = None
+ self.keys = None
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ example = self.data[i]
+ ex = {}
+ if self.keys is not None:
+ for k in self.keys:
+ ex[k] = example[k]
+ else:
+ ex = example
+ return ex
+
+
+class CelebAHQTrain(FacesBase):
+ def __init__(self, size, keys=None):
+ super().__init__()
+ root = "data/celebahq"
+ with open("data/celebahqtrain.txt", "r") as f:
+ relpaths = f.read().splitlines()
+ paths = [os.path.join(root, relpath) for relpath in relpaths]
+ self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
+ self.keys = keys
+
+
+class CelebAHQValidation(FacesBase):
+ def __init__(self, size, keys=None):
+ super().__init__()
+ root = "data/celebahq"
+ with open("data/celebahqvalidation.txt", "r") as f:
+ relpaths = f.read().splitlines()
+ paths = [os.path.join(root, relpath) for relpath in relpaths]
+ self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
+ self.keys = keys
+
+
+class FFHQTrain(FacesBase):
+ def __init__(self, size, keys=None):
+ super().__init__()
+ root = "data/ffhq"
+ with open("data/ffhqtrain.txt", "r") as f:
+ relpaths = f.read().splitlines()
+ paths = [os.path.join(root, relpath) for relpath in relpaths]
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
+ self.keys = keys
+
+
+class FFHQValidation(FacesBase):
+ def __init__(self, size, keys=None):
+ super().__init__()
+ root = "data/ffhq"
+ with open("data/ffhqvalidation.txt", "r") as f:
+ relpaths = f.read().splitlines()
+ paths = [os.path.join(root, relpath) for relpath in relpaths]
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
+ self.keys = keys
+
+
+class FacesHQTrain(Dataset):
+ # CelebAHQ [0] + FFHQ [1]
+ def __init__(self, size, keys=None, crop_size=None, coord=False):
+ d1 = CelebAHQTrain(size=size, keys=keys)
+ d2 = FFHQTrain(size=size, keys=keys)
+ self.data = ConcatDatasetWithIndex([d1, d2])
+ self.coord = coord
+ if crop_size is not None:
+ self.cropper = albumentations.RandomCrop(height=crop_size,width=crop_size)
+ if self.coord:
+ self.cropper = albumentations.Compose([self.cropper],
+ additional_targets={"coord": "image"})
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ ex, y = self.data[i]
+ if hasattr(self, "cropper"):
+ if not self.coord:
+ out = self.cropper(image=ex["image"])
+ ex["image"] = out["image"]
+ else:
+ h,w,_ = ex["image"].shape
+ coord = np.arange(h*w).reshape(h,w,1)/(h*w)
+ out = self.cropper(image=ex["image"], coord=coord)
+ ex["image"] = out["image"]
+ ex["coord"] = out["coord"]
+ ex["class"] = y
+ return ex
+
+
+class FacesHQValidation(Dataset):
+ # CelebAHQ [0] + FFHQ [1]
+ def __init__(self, size, keys=None, crop_size=None, coord=False):
+ d1 = CelebAHQValidation(size=size, keys=keys)
+ d2 = FFHQValidation(size=size, keys=keys)
+ self.data = ConcatDatasetWithIndex([d1, d2])
+ self.coord = coord
+ if crop_size is not None:
+ self.cropper = albumentations.CenterCrop(height=crop_size,width=crop_size)
+ if self.coord:
+ self.cropper = albumentations.Compose([self.cropper],
+ additional_targets={"coord": "image"})
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ ex, y = self.data[i]
+ if hasattr(self, "cropper"):
+ if not self.coord:
+ out = self.cropper(image=ex["image"])
+ ex["image"] = out["image"]
+ else:
+ h,w,_ = ex["image"].shape
+ coord = np.arange(h*w).reshape(h,w,1)/(h*w)
+ out = self.cropper(image=ex["image"], coord=coord)
+ ex["image"] = out["image"]
+ ex["coord"] = out["coord"]
+ ex["class"] = y
+ return ex
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/data/helper_types.py b/src/custom_controlnet_aux/diffusion_edge/taming/data/helper_types.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb51e301da08602cfead5961c4f7e1d89f6aba79
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/data/helper_types.py
@@ -0,0 +1,49 @@
+from typing import Dict, Tuple, Optional, NamedTuple, Union
+from PIL.Image import Image as pil_image
+from torch import Tensor
+
+try:
+ from typing import Literal
+except ImportError:
+ from typing_extensions import Literal
+
+Image = Union[Tensor, pil_image]
+BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h
+CropMethodType = Literal['none', 'random', 'center', 'random-2d']
+SplitType = Literal['train', 'validation', 'test']
+
+
+class ImageDescription(NamedTuple):
+ id: int
+ file_name: str
+ original_size: Tuple[int, int] # w, h
+ url: Optional[str] = None
+ license: Optional[int] = None
+ coco_url: Optional[str] = None
+ date_captured: Optional[str] = None
+ flickr_url: Optional[str] = None
+ flickr_id: Optional[str] = None
+ coco_id: Optional[str] = None
+
+
+class Category(NamedTuple):
+ id: str
+ super_category: Optional[str]
+ name: str
+
+
+class Annotation(NamedTuple):
+ area: float
+ image_id: str
+ bbox: BoundingBox
+ category_no: int
+ category_id: str
+ id: Optional[int] = None
+ source: Optional[str] = None
+ confidence: Optional[float] = None
+ is_group_of: Optional[bool] = None
+ is_truncated: Optional[bool] = None
+ is_occluded: Optional[bool] = None
+ is_depiction: Optional[bool] = None
+ is_inside: Optional[bool] = None
+ segmentation: Optional[Dict] = None
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/data/image_transforms.py b/src/custom_controlnet_aux/diffusion_edge/taming/data/image_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1289f06c7df1a4806eaa2c58fc02e5bdcbc62f8
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/data/image_transforms.py
@@ -0,0 +1,132 @@
+import random
+import warnings
+from typing import Union
+
+import torch
+from torch import Tensor
+from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor
+from torchvision.transforms.functional import _get_image_size as get_image_size
+
+from custom_controlnet_aux.diffusion_edge.taming.data.helper_types import BoundingBox, Image
+
+pil_to_tensor = PILToTensor()
+
+
+def convert_pil_to_tensor(image: Image) -> Tensor:
+ with warnings.catch_warnings():
+ # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194
+ warnings.simplefilter("ignore")
+ return pil_to_tensor(image)
+
+
+class RandomCrop1dReturnCoordinates(RandomCrop):
+ def forward(self, img: Image) -> (BoundingBox, Image):
+ """
+ Additionally to cropping, returns the relative coordinates of the crop bounding box.
+ Args:
+ img (PIL Image or Tensor): Image to be cropped.
+
+ Returns:
+ Bounding box: x0, y0, w, h
+ PIL Image or Tensor: Cropped image.
+
+ Based on:
+ torchvision.transforms.RandomCrop, torchvision 1.7.0
+ """
+ if self.padding is not None:
+ img = F.pad(img, self.padding, self.fill, self.padding_mode)
+
+ width, height = get_image_size(img)
+ # pad the width if needed
+ if self.pad_if_needed and width < self.size[1]:
+ padding = [self.size[1] - width, 0]
+ img = F.pad(img, padding, self.fill, self.padding_mode)
+ # pad the height if needed
+ if self.pad_if_needed and height < self.size[0]:
+ padding = [0, self.size[0] - height]
+ img = F.pad(img, padding, self.fill, self.padding_mode)
+
+ i, j, h, w = self.get_params(img, self.size)
+ bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h
+ return bbox, F.crop(img, i, j, h, w)
+
+
+class Random2dCropReturnCoordinates(torch.nn.Module):
+ """
+ Additionally to cropping, returns the relative coordinates of the crop bounding box.
+ Args:
+ img (PIL Image or Tensor): Image to be cropped.
+
+ Returns:
+ Bounding box: x0, y0, w, h
+ PIL Image or Tensor: Cropped image.
+
+ Based on:
+ torchvision.transforms.RandomCrop, torchvision 1.7.0
+ """
+
+ def __init__(self, min_size: int):
+ super().__init__()
+ self.min_size = min_size
+
+ def forward(self, img: Image) -> (BoundingBox, Image):
+ width, height = get_image_size(img)
+ max_size = min(width, height)
+ if max_size <= self.min_size:
+ size = max_size
+ else:
+ size = random.randint(self.min_size, max_size)
+ top = random.randint(0, height - size)
+ left = random.randint(0, width - size)
+ bbox = left / width, top / height, size / width, size / height
+ return bbox, F.crop(img, top, left, size, size)
+
+
+class CenterCropReturnCoordinates(CenterCrop):
+ @staticmethod
+ def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox:
+ if width > height:
+ w = height / width
+ h = 1.0
+ x0 = 0.5 - w / 2
+ y0 = 0.
+ else:
+ w = 1.0
+ h = width / height
+ x0 = 0.
+ y0 = 0.5 - h / 2
+ return x0, y0, w, h
+
+ def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]):
+ """
+ Additionally to cropping, returns the relative coordinates of the crop bounding box.
+ Args:
+ img (PIL Image or Tensor): Image to be cropped.
+
+ Returns:
+ Bounding box: x0, y0, w, h
+ PIL Image or Tensor: Cropped image.
+ Based on:
+ torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
+ """
+ width, height = get_image_size(img)
+ return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size)
+
+
+class RandomHorizontalFlipReturn(RandomHorizontalFlip):
+ def forward(self, img: Image) -> (bool, Image):
+ """
+ Additionally to flipping, returns a boolean whether it was flipped or not.
+ Args:
+ img (PIL Image or Tensor): Image to be flipped.
+
+ Returns:
+ flipped: whether the image was flipped or not
+ PIL Image or Tensor: Randomly flipped image.
+
+ Based on:
+ torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
+ """
+ if torch.rand(1) < self.p:
+ return True, F.hflip(img)
+ return False, img
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/data/imagenet.py b/src/custom_controlnet_aux/diffusion_edge/taming/data/imagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfe485d3f81dffe5a40e1ab1343207fbdd11f5cc
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/data/imagenet.py
@@ -0,0 +1,558 @@
+import os, tarfile, glob, shutil
+import yaml
+import numpy as np
+from tqdm import tqdm
+from PIL import Image
+import albumentations
+from omegaconf import OmegaConf
+from torch.utils.data import Dataset
+
+from custom_controlnet_aux.diffusion_edge.taming.data.base import ImagePaths
+from custom_controlnet_aux.diffusion_edge.taming.util import download, retrieve
+import taming.data.utils as bdu
+
+
+def give_synsets_from_indices(indices, path_to_yaml="data/imagenet_idx_to_synset.yaml"):
+ synsets = []
+ with open(path_to_yaml) as f:
+ di2s = yaml.load(f)
+ for idx in indices:
+ synsets.append(str(di2s[idx]))
+ print("Using {} different synsets for construction of Restriced Imagenet.".format(len(synsets)))
+ return synsets
+
+
+def str_to_indices(string):
+ """Expects a string in the format '32-123, 256, 280-321'"""
+ assert not string.endswith(","), "provided string '{}' ends with a comma, pls remove it".format(string)
+ subs = string.split(",")
+ indices = []
+ for sub in subs:
+ subsubs = sub.split("-")
+ assert len(subsubs) > 0
+ if len(subsubs) == 1:
+ indices.append(int(subsubs[0]))
+ else:
+ rang = [j for j in range(int(subsubs[0]), int(subsubs[1]))]
+ indices.extend(rang)
+ return sorted(indices)
+
+
+class ImageNetBase(Dataset):
+ def __init__(self, config=None):
+ self.config = config or OmegaConf.create()
+ if not type(self.config)==dict:
+ self.config = OmegaConf.to_container(self.config)
+ self._prepare()
+ self._prepare_synset_to_human()
+ self._prepare_idx_to_synset()
+ self._load()
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ return self.data[i]
+
+ def _prepare(self):
+ raise NotImplementedError()
+
+ def _filter_relpaths(self, relpaths):
+ ignore = set([
+ "n06596364_9591.JPEG",
+ ])
+ relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
+ if "sub_indices" in self.config:
+ indices = str_to_indices(self.config["sub_indices"])
+ synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
+ files = []
+ for rpath in relpaths:
+ syn = rpath.split("/")[0]
+ if syn in synsets:
+ files.append(rpath)
+ return files
+ else:
+ return relpaths
+
+ def _prepare_synset_to_human(self):
+ SIZE = 2655750
+ URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
+ self.human_dict = os.path.join(self.root, "synset_human.txt")
+ if (not os.path.exists(self.human_dict) or
+ not os.path.getsize(self.human_dict)==SIZE):
+ download(URL, self.human_dict)
+
+ def _prepare_idx_to_synset(self):
+ URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
+ self.idx2syn = os.path.join(self.root, "index_synset.yaml")
+ if (not os.path.exists(self.idx2syn)):
+ download(URL, self.idx2syn)
+
+ def _load(self):
+ with open(self.txt_filelist, "r") as f:
+ self.relpaths = f.read().splitlines()
+ l1 = len(self.relpaths)
+ self.relpaths = self._filter_relpaths(self.relpaths)
+ print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
+
+ self.synsets = [p.split("/")[0] for p in self.relpaths]
+ self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
+
+ unique_synsets = np.unique(self.synsets)
+ class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
+ self.class_labels = [class_dict[s] for s in self.synsets]
+
+ with open(self.human_dict, "r") as f:
+ human_dict = f.read().splitlines()
+ human_dict = dict(line.split(maxsplit=1) for line in human_dict)
+
+ self.human_labels = [human_dict[s] for s in self.synsets]
+
+ labels = {
+ "relpath": np.array(self.relpaths),
+ "synsets": np.array(self.synsets),
+ "class_label": np.array(self.class_labels),
+ "human_label": np.array(self.human_labels),
+ }
+ self.data = ImagePaths(self.abspaths,
+ labels=labels,
+ size=retrieve(self.config, "size", default=0),
+ random_crop=self.random_crop)
+
+
+class ImageNetTrain(ImageNetBase):
+ NAME = "ILSVRC2012_train"
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
+ AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
+ FILES = [
+ "ILSVRC2012_img_train.tar",
+ ]
+ SIZES = [
+ 147897477120,
+ ]
+
+ def _prepare(self):
+ self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
+ default=True)
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
+ self.datadir = os.path.join(self.root, "data")
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
+ self.expected_length = 1281167
+ if not bdu.is_prepared(self.root):
+ # prep
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
+
+ datadir = self.datadir
+ if not os.path.exists(datadir):
+ path = os.path.join(self.root, self.FILES[0])
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
+ import academictorrents as at
+ atpath = at.get(self.AT_HASH, datastore=self.root)
+ assert atpath == path
+
+ print("Extracting {} to {}".format(path, datadir))
+ os.makedirs(datadir, exist_ok=True)
+ with tarfile.open(path, "r:") as tar:
+ tar.extractall(path=datadir)
+
+ print("Extracting sub-tars.")
+ subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
+ for subpath in tqdm(subpaths):
+ subdir = subpath[:-len(".tar")]
+ os.makedirs(subdir, exist_ok=True)
+ with tarfile.open(subpath, "r:") as tar:
+ tar.extractall(path=subdir)
+
+
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
+ filelist = sorted(filelist)
+ filelist = "\n".join(filelist)+"\n"
+ with open(self.txt_filelist, "w") as f:
+ f.write(filelist)
+
+ bdu.mark_prepared(self.root)
+
+
+class ImageNetValidation(ImageNetBase):
+ NAME = "ILSVRC2012_validation"
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
+ AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
+ VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
+ FILES = [
+ "ILSVRC2012_img_val.tar",
+ "validation_synset.txt",
+ ]
+ SIZES = [
+ 6744924160,
+ 1950000,
+ ]
+
+ def _prepare(self):
+ self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
+ default=False)
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
+ self.datadir = os.path.join(self.root, "data")
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
+ self.expected_length = 50000
+ if not bdu.is_prepared(self.root):
+ # prep
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
+
+ datadir = self.datadir
+ if not os.path.exists(datadir):
+ path = os.path.join(self.root, self.FILES[0])
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
+ import academictorrents as at
+ atpath = at.get(self.AT_HASH, datastore=self.root)
+ assert atpath == path
+
+ print("Extracting {} to {}".format(path, datadir))
+ os.makedirs(datadir, exist_ok=True)
+ with tarfile.open(path, "r:") as tar:
+ tar.extractall(path=datadir)
+
+ vspath = os.path.join(self.root, self.FILES[1])
+ if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
+ download(self.VS_URL, vspath)
+
+ with open(vspath, "r") as f:
+ synset_dict = f.read().splitlines()
+ synset_dict = dict(line.split() for line in synset_dict)
+
+ print("Reorganizing into synset folders")
+ synsets = np.unique(list(synset_dict.values()))
+ for s in synsets:
+ os.makedirs(os.path.join(datadir, s), exist_ok=True)
+ for k, v in synset_dict.items():
+ src = os.path.join(datadir, k)
+ dst = os.path.join(datadir, v)
+ shutil.move(src, dst)
+
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
+ filelist = sorted(filelist)
+ filelist = "\n".join(filelist)+"\n"
+ with open(self.txt_filelist, "w") as f:
+ f.write(filelist)
+
+ bdu.mark_prepared(self.root)
+
+
+def get_preprocessor(size=None, random_crop=False, additional_targets=None,
+ crop_size=None):
+ if size is not None and size > 0:
+ transforms = list()
+ rescaler = albumentations.SmallestMaxSize(max_size = size)
+ transforms.append(rescaler)
+ if not random_crop:
+ cropper = albumentations.CenterCrop(height=size,width=size)
+ transforms.append(cropper)
+ else:
+ cropper = albumentations.RandomCrop(height=size,width=size)
+ transforms.append(cropper)
+ flipper = albumentations.HorizontalFlip()
+ transforms.append(flipper)
+ preprocessor = albumentations.Compose(transforms,
+ additional_targets=additional_targets)
+ elif crop_size is not None and crop_size > 0:
+ if not random_crop:
+ cropper = albumentations.CenterCrop(height=crop_size,width=crop_size)
+ else:
+ cropper = albumentations.RandomCrop(height=crop_size,width=crop_size)
+ transforms = [cropper]
+ preprocessor = albumentations.Compose(transforms,
+ additional_targets=additional_targets)
+ else:
+ preprocessor = lambda **kwargs: kwargs
+ return preprocessor
+
+
+def rgba_to_depth(x):
+ assert x.dtype == np.uint8
+ assert len(x.shape) == 3 and x.shape[2] == 4
+ y = x.copy()
+ y.dtype = np.float32
+ y = y.reshape(x.shape[:2])
+ return np.ascontiguousarray(y)
+
+
+class BaseWithDepth(Dataset):
+ DEFAULT_DEPTH_ROOT="data/imagenet_depth"
+
+ def __init__(self, config=None, size=None, random_crop=False,
+ crop_size=None, root=None):
+ self.config = config
+ self.base_dset = self.get_base_dset()
+ self.preprocessor = get_preprocessor(
+ size=size,
+ crop_size=crop_size,
+ random_crop=random_crop,
+ additional_targets={"depth": "image"})
+ self.crop_size = crop_size
+ if self.crop_size is not None:
+ self.rescaler = albumentations.Compose(
+ [albumentations.SmallestMaxSize(max_size = self.crop_size)],
+ additional_targets={"depth": "image"})
+ if root is not None:
+ self.DEFAULT_DEPTH_ROOT = root
+
+ def __len__(self):
+ return len(self.base_dset)
+
+ def preprocess_depth(self, path):
+ rgba = np.array(Image.open(path))
+ depth = rgba_to_depth(rgba)
+ depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min())
+ depth = 2.0*depth-1.0
+ return depth
+
+ def __getitem__(self, i):
+ e = self.base_dset[i]
+ e["depth"] = self.preprocess_depth(self.get_depth_path(e))
+ # up if necessary
+ h,w,c = e["image"].shape
+ if self.crop_size and min(h,w) < self.crop_size:
+ # have to upscale to be able to crop - this just uses bilinear
+ out = self.rescaler(image=e["image"], depth=e["depth"])
+ e["image"] = out["image"]
+ e["depth"] = out["depth"]
+ transformed = self.preprocessor(image=e["image"], depth=e["depth"])
+ e["image"] = transformed["image"]
+ e["depth"] = transformed["depth"]
+ return e
+
+
+class ImageNetTrainWithDepth(BaseWithDepth):
+ # default to random_crop=True
+ def __init__(self, random_crop=True, sub_indices=None, **kwargs):
+ self.sub_indices = sub_indices
+ super().__init__(random_crop=random_crop, **kwargs)
+
+ def get_base_dset(self):
+ if self.sub_indices is None:
+ return ImageNetTrain()
+ else:
+ return ImageNetTrain({"sub_indices": self.sub_indices})
+
+ def get_depth_path(self, e):
+ fid = os.path.splitext(e["relpath"])[0]+".png"
+ fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "train", fid)
+ return fid
+
+
+class ImageNetValidationWithDepth(BaseWithDepth):
+ def __init__(self, sub_indices=None, **kwargs):
+ self.sub_indices = sub_indices
+ super().__init__(**kwargs)
+
+ def get_base_dset(self):
+ if self.sub_indices is None:
+ return ImageNetValidation()
+ else:
+ return ImageNetValidation({"sub_indices": self.sub_indices})
+
+ def get_depth_path(self, e):
+ fid = os.path.splitext(e["relpath"])[0]+".png"
+ fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "val", fid)
+ return fid
+
+
+class RINTrainWithDepth(ImageNetTrainWithDepth):
+ def __init__(self, config=None, size=None, random_crop=True, crop_size=None):
+ sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
+ super().__init__(config=config, size=size, random_crop=random_crop,
+ sub_indices=sub_indices, crop_size=crop_size)
+
+
+class RINValidationWithDepth(ImageNetValidationWithDepth):
+ def __init__(self, config=None, size=None, random_crop=False, crop_size=None):
+ sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
+ super().__init__(config=config, size=size, random_crop=random_crop,
+ sub_indices=sub_indices, crop_size=crop_size)
+
+
+class DRINExamples(Dataset):
+ def __init__(self):
+ self.preprocessor = get_preprocessor(size=256, additional_targets={"depth": "image"})
+ with open("data/drin_examples.txt", "r") as f:
+ relpaths = f.read().splitlines()
+ self.image_paths = [os.path.join("data/drin_images",
+ relpath) for relpath in relpaths]
+ self.depth_paths = [os.path.join("data/drin_depth",
+ relpath.replace(".JPEG", ".png")) for relpath in relpaths]
+
+ def __len__(self):
+ return len(self.image_paths)
+
+ def preprocess_image(self, image_path):
+ image = Image.open(image_path)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+ image = self.preprocessor(image=image)["image"]
+ image = (image/127.5 - 1.0).astype(np.float32)
+ return image
+
+ def preprocess_depth(self, path):
+ rgba = np.array(Image.open(path))
+ depth = rgba_to_depth(rgba)
+ depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min())
+ depth = 2.0*depth-1.0
+ return depth
+
+ def __getitem__(self, i):
+ e = dict()
+ e["image"] = self.preprocess_image(self.image_paths[i])
+ e["depth"] = self.preprocess_depth(self.depth_paths[i])
+ transformed = self.preprocessor(image=e["image"], depth=e["depth"])
+ e["image"] = transformed["image"]
+ e["depth"] = transformed["depth"]
+ return e
+
+
+def imscale(x, factor, keepshapes=False, keepmode="bicubic"):
+ if factor is None or factor==1:
+ return x
+
+ dtype = x.dtype
+ assert dtype in [np.float32, np.float64]
+ assert x.min() >= -1
+ assert x.max() <= 1
+
+ keepmode = {"nearest": Image.NEAREST, "bilinear": Image.BILINEAR,
+ "bicubic": Image.BICUBIC}[keepmode]
+
+ lr = (x+1.0)*127.5
+ lr = lr.clip(0,255).astype(np.uint8)
+ lr = Image.fromarray(lr)
+
+ h, w, _ = x.shape
+ nh = h//factor
+ nw = w//factor
+ assert nh > 0 and nw > 0, (nh, nw)
+
+ lr = lr.resize((nw,nh), Image.BICUBIC)
+ if keepshapes:
+ lr = lr.resize((w,h), keepmode)
+ lr = np.array(lr)/127.5-1.0
+ lr = lr.astype(dtype)
+
+ return lr
+
+
+class ImageNetScale(Dataset):
+ def __init__(self, size=None, crop_size=None, random_crop=False,
+ up_factor=None, hr_factor=None, keep_mode="bicubic"):
+ self.base = self.get_base()
+
+ self.size = size
+ self.crop_size = crop_size if crop_size is not None else self.size
+ self.random_crop = random_crop
+ self.up_factor = up_factor
+ self.hr_factor = hr_factor
+ self.keep_mode = keep_mode
+
+ transforms = list()
+
+ if self.size is not None and self.size > 0:
+ rescaler = albumentations.SmallestMaxSize(max_size = self.size)
+ self.rescaler = rescaler
+ transforms.append(rescaler)
+
+ if self.crop_size is not None and self.crop_size > 0:
+ if len(transforms) == 0:
+ self.rescaler = albumentations.SmallestMaxSize(max_size = self.crop_size)
+
+ if not self.random_crop:
+ cropper = albumentations.CenterCrop(height=self.crop_size,width=self.crop_size)
+ else:
+ cropper = albumentations.RandomCrop(height=self.crop_size,width=self.crop_size)
+ transforms.append(cropper)
+
+ if len(transforms) > 0:
+ if self.up_factor is not None:
+ additional_targets = {"lr": "image"}
+ else:
+ additional_targets = None
+ self.preprocessor = albumentations.Compose(transforms,
+ additional_targets=additional_targets)
+ else:
+ self.preprocessor = lambda **kwargs: kwargs
+
+ def __len__(self):
+ return len(self.base)
+
+ def __getitem__(self, i):
+ example = self.base[i]
+ image = example["image"]
+ # adjust resolution
+ image = imscale(image, self.hr_factor, keepshapes=False)
+ h,w,c = image.shape
+ if self.crop_size and min(h,w) < self.crop_size:
+ # have to upscale to be able to crop - this just uses bilinear
+ image = self.rescaler(image=image)["image"]
+ if self.up_factor is None:
+ image = self.preprocessor(image=image)["image"]
+ example["image"] = image
+ else:
+ lr = imscale(image, self.up_factor, keepshapes=True,
+ keepmode=self.keep_mode)
+
+ out = self.preprocessor(image=image, lr=lr)
+ example["image"] = out["image"]
+ example["lr"] = out["lr"]
+
+ return example
+
+class ImageNetScaleTrain(ImageNetScale):
+ def __init__(self, random_crop=True, **kwargs):
+ super().__init__(random_crop=random_crop, **kwargs)
+
+ def get_base(self):
+ return ImageNetTrain()
+
+class ImageNetScaleValidation(ImageNetScale):
+ def get_base(self):
+ return ImageNetValidation()
+
+
+from skimage.feature import canny
+from skimage.color import rgb2gray
+
+
+class ImageNetEdges(ImageNetScale):
+ def __init__(self, up_factor=1, **kwargs):
+ super().__init__(up_factor=1, **kwargs)
+
+ def __getitem__(self, i):
+ example = self.base[i]
+ image = example["image"]
+ h,w,c = image.shape
+ if self.crop_size and min(h,w) < self.crop_size:
+ # have to upscale to be able to crop - this just uses bilinear
+ image = self.rescaler(image=image)["image"]
+
+ lr = canny(rgb2gray(image), sigma=2)
+ lr = lr.astype(np.float32)
+ lr = lr[:,:,None][:,:,[0,0,0]]
+
+ out = self.preprocessor(image=image, lr=lr)
+ example["image"] = out["image"]
+ example["lr"] = out["lr"]
+
+ return example
+
+
+class ImageNetEdgesTrain(ImageNetEdges):
+ def __init__(self, random_crop=True, **kwargs):
+ super().__init__(random_crop=random_crop, **kwargs)
+
+ def get_base(self):
+ return ImageNetTrain()
+
+class ImageNetEdgesValidation(ImageNetEdges):
+ def get_base(self):
+ return ImageNetValidation()
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/data/open_images_helper.py b/src/custom_controlnet_aux/diffusion_edge/taming/data/open_images_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..8feb7c6e705fc165d2983303192aaa88f579b243
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/data/open_images_helper.py
@@ -0,0 +1,379 @@
+open_images_unify_categories_for_coco = {
+ '/m/03bt1vf': '/m/01g317',
+ '/m/04yx4': '/m/01g317',
+ '/m/05r655': '/m/01g317',
+ '/m/01bl7v': '/m/01g317',
+ '/m/0cnyhnx': '/m/01xq0k1',
+ '/m/01226z': '/m/018xm',
+ '/m/05ctyq': '/m/018xm',
+ '/m/058qzx': '/m/04ctx',
+ '/m/06pcq': '/m/0l515',
+ '/m/03m3pdh': '/m/02crq1',
+ '/m/046dlr': '/m/01x3z',
+ '/m/0h8mzrc': '/m/01x3z',
+}
+
+
+top_300_classes_plus_coco_compatibility = [
+ ('Man', 1060962),
+ ('Clothing', 986610),
+ ('Tree', 748162),
+ ('Woman', 611896),
+ ('Person', 610294),
+ ('Human face', 442948),
+ ('Girl', 175399),
+ ('Building', 162147),
+ ('Car', 159135),
+ ('Plant', 155704),
+ ('Human body', 137073),
+ ('Flower', 133128),
+ ('Window', 127485),
+ ('Human arm', 118380),
+ ('House', 114365),
+ ('Wheel', 111684),
+ ('Suit', 99054),
+ ('Human hair', 98089),
+ ('Human head', 92763),
+ ('Chair', 88624),
+ ('Boy', 79849),
+ ('Table', 73699),
+ ('Jeans', 57200),
+ ('Tire', 55725),
+ ('Skyscraper', 53321),
+ ('Food', 52400),
+ ('Footwear', 50335),
+ ('Dress', 50236),
+ ('Human leg', 47124),
+ ('Toy', 46636),
+ ('Tower', 45605),
+ ('Boat', 43486),
+ ('Land vehicle', 40541),
+ ('Bicycle wheel', 34646),
+ ('Palm tree', 33729),
+ ('Fashion accessory', 32914),
+ ('Glasses', 31940),
+ ('Bicycle', 31409),
+ ('Furniture', 30656),
+ ('Sculpture', 29643),
+ ('Bottle', 27558),
+ ('Dog', 26980),
+ ('Snack', 26796),
+ ('Human hand', 26664),
+ ('Bird', 25791),
+ ('Book', 25415),
+ ('Guitar', 24386),
+ ('Jacket', 23998),
+ ('Poster', 22192),
+ ('Dessert', 21284),
+ ('Baked goods', 20657),
+ ('Drink', 19754),
+ ('Flag', 18588),
+ ('Houseplant', 18205),
+ ('Tableware', 17613),
+ ('Airplane', 17218),
+ ('Door', 17195),
+ ('Sports uniform', 17068),
+ ('Shelf', 16865),
+ ('Drum', 16612),
+ ('Vehicle', 16542),
+ ('Microphone', 15269),
+ ('Street light', 14957),
+ ('Cat', 14879),
+ ('Fruit', 13684),
+ ('Fast food', 13536),
+ ('Animal', 12932),
+ ('Vegetable', 12534),
+ ('Train', 12358),
+ ('Horse', 11948),
+ ('Flowerpot', 11728),
+ ('Motorcycle', 11621),
+ ('Fish', 11517),
+ ('Desk', 11405),
+ ('Helmet', 10996),
+ ('Truck', 10915),
+ ('Bus', 10695),
+ ('Hat', 10532),
+ ('Auto part', 10488),
+ ('Musical instrument', 10303),
+ ('Sunglasses', 10207),
+ ('Picture frame', 10096),
+ ('Sports equipment', 10015),
+ ('Shorts', 9999),
+ ('Wine glass', 9632),
+ ('Duck', 9242),
+ ('Wine', 9032),
+ ('Rose', 8781),
+ ('Tie', 8693),
+ ('Butterfly', 8436),
+ ('Beer', 7978),
+ ('Cabinetry', 7956),
+ ('Laptop', 7907),
+ ('Insect', 7497),
+ ('Goggles', 7363),
+ ('Shirt', 7098),
+ ('Dairy Product', 7021),
+ ('Marine invertebrates', 7014),
+ ('Cattle', 7006),
+ ('Trousers', 6903),
+ ('Van', 6843),
+ ('Billboard', 6777),
+ ('Balloon', 6367),
+ ('Human nose', 6103),
+ ('Tent', 6073),
+ ('Camera', 6014),
+ ('Doll', 6002),
+ ('Coat', 5951),
+ ('Mobile phone', 5758),
+ ('Swimwear', 5729),
+ ('Strawberry', 5691),
+ ('Stairs', 5643),
+ ('Goose', 5599),
+ ('Umbrella', 5536),
+ ('Cake', 5508),
+ ('Sun hat', 5475),
+ ('Bench', 5310),
+ ('Bookcase', 5163),
+ ('Bee', 5140),
+ ('Computer monitor', 5078),
+ ('Hiking equipment', 4983),
+ ('Office building', 4981),
+ ('Coffee cup', 4748),
+ ('Curtain', 4685),
+ ('Plate', 4651),
+ ('Box', 4621),
+ ('Tomato', 4595),
+ ('Coffee table', 4529),
+ ('Office supplies', 4473),
+ ('Maple', 4416),
+ ('Muffin', 4365),
+ ('Cocktail', 4234),
+ ('Castle', 4197),
+ ('Couch', 4134),
+ ('Pumpkin', 3983),
+ ('Computer keyboard', 3960),
+ ('Human mouth', 3926),
+ ('Christmas tree', 3893),
+ ('Mushroom', 3883),
+ ('Swimming pool', 3809),
+ ('Pastry', 3799),
+ ('Lavender (Plant)', 3769),
+ ('Football helmet', 3732),
+ ('Bread', 3648),
+ ('Traffic sign', 3628),
+ ('Common sunflower', 3597),
+ ('Television', 3550),
+ ('Bed', 3525),
+ ('Cookie', 3485),
+ ('Fountain', 3484),
+ ('Paddle', 3447),
+ ('Bicycle helmet', 3429),
+ ('Porch', 3420),
+ ('Deer', 3387),
+ ('Fedora', 3339),
+ ('Canoe', 3338),
+ ('Carnivore', 3266),
+ ('Bowl', 3202),
+ ('Human eye', 3166),
+ ('Ball', 3118),
+ ('Pillow', 3077),
+ ('Salad', 3061),
+ ('Beetle', 3060),
+ ('Orange', 3050),
+ ('Drawer', 2958),
+ ('Platter', 2937),
+ ('Elephant', 2921),
+ ('Seafood', 2921),
+ ('Monkey', 2915),
+ ('Countertop', 2879),
+ ('Watercraft', 2831),
+ ('Helicopter', 2805),
+ ('Kitchen appliance', 2797),
+ ('Personal flotation device', 2781),
+ ('Swan', 2739),
+ ('Lamp', 2711),
+ ('Boot', 2695),
+ ('Bronze sculpture', 2693),
+ ('Chicken', 2677),
+ ('Taxi', 2643),
+ ('Juice', 2615),
+ ('Cowboy hat', 2604),
+ ('Apple', 2600),
+ ('Tin can', 2590),
+ ('Necklace', 2564),
+ ('Ice cream', 2560),
+ ('Human beard', 2539),
+ ('Coin', 2536),
+ ('Candle', 2515),
+ ('Cart', 2512),
+ ('High heels', 2441),
+ ('Weapon', 2433),
+ ('Handbag', 2406),
+ ('Penguin', 2396),
+ ('Rifle', 2352),
+ ('Violin', 2336),
+ ('Skull', 2304),
+ ('Lantern', 2285),
+ ('Scarf', 2269),
+ ('Saucer', 2225),
+ ('Sheep', 2215),
+ ('Vase', 2189),
+ ('Lily', 2180),
+ ('Mug', 2154),
+ ('Parrot', 2140),
+ ('Human ear', 2137),
+ ('Sandal', 2115),
+ ('Lizard', 2100),
+ ('Kitchen & dining room table', 2063),
+ ('Spider', 1977),
+ ('Coffee', 1974),
+ ('Goat', 1926),
+ ('Squirrel', 1922),
+ ('Cello', 1913),
+ ('Sushi', 1881),
+ ('Tortoise', 1876),
+ ('Pizza', 1870),
+ ('Studio couch', 1864),
+ ('Barrel', 1862),
+ ('Cosmetics', 1841),
+ ('Moths and butterflies', 1841),
+ ('Convenience store', 1817),
+ ('Watch', 1792),
+ ('Home appliance', 1786),
+ ('Harbor seal', 1780),
+ ('Luggage and bags', 1756),
+ ('Vehicle registration plate', 1754),
+ ('Shrimp', 1751),
+ ('Jellyfish', 1730),
+ ('French fries', 1723),
+ ('Egg (Food)', 1698),
+ ('Football', 1697),
+ ('Musical keyboard', 1683),
+ ('Falcon', 1674),
+ ('Candy', 1660),
+ ('Medical equipment', 1654),
+ ('Eagle', 1651),
+ ('Dinosaur', 1634),
+ ('Surfboard', 1630),
+ ('Tank', 1628),
+ ('Grape', 1624),
+ ('Lion', 1624),
+ ('Owl', 1622),
+ ('Ski', 1613),
+ ('Waste container', 1606),
+ ('Frog', 1591),
+ ('Sparrow', 1585),
+ ('Rabbit', 1581),
+ ('Pen', 1546),
+ ('Sea lion', 1537),
+ ('Spoon', 1521),
+ ('Sink', 1512),
+ ('Teddy bear', 1507),
+ ('Bull', 1495),
+ ('Sofa bed', 1490),
+ ('Dragonfly', 1479),
+ ('Brassiere', 1478),
+ ('Chest of drawers', 1472),
+ ('Aircraft', 1466),
+ ('Human foot', 1463),
+ ('Pig', 1455),
+ ('Fork', 1454),
+ ('Antelope', 1438),
+ ('Tripod', 1427),
+ ('Tool', 1424),
+ ('Cheese', 1422),
+ ('Lemon', 1397),
+ ('Hamburger', 1393),
+ ('Dolphin', 1390),
+ ('Mirror', 1390),
+ ('Marine mammal', 1387),
+ ('Giraffe', 1385),
+ ('Snake', 1368),
+ ('Gondola', 1364),
+ ('Wheelchair', 1360),
+ ('Piano', 1358),
+ ('Cupboard', 1348),
+ ('Banana', 1345),
+ ('Trumpet', 1335),
+ ('Lighthouse', 1333),
+ ('Invertebrate', 1317),
+ ('Carrot', 1268),
+ ('Sock', 1260),
+ ('Tiger', 1241),
+ ('Camel', 1224),
+ ('Parachute', 1224),
+ ('Bathroom accessory', 1223),
+ ('Earrings', 1221),
+ ('Headphones', 1218),
+ ('Skirt', 1198),
+ ('Skateboard', 1190),
+ ('Sandwich', 1148),
+ ('Saxophone', 1141),
+ ('Goldfish', 1136),
+ ('Stool', 1104),
+ ('Traffic light', 1097),
+ ('Shellfish', 1081),
+ ('Backpack', 1079),
+ ('Sea turtle', 1078),
+ ('Cucumber', 1075),
+ ('Tea', 1051),
+ ('Toilet', 1047),
+ ('Roller skates', 1040),
+ ('Mule', 1039),
+ ('Bust', 1031),
+ ('Broccoli', 1030),
+ ('Crab', 1020),
+ ('Oyster', 1019),
+ ('Cannon', 1012),
+ ('Zebra', 1012),
+ ('French horn', 1008),
+ ('Grapefruit', 998),
+ ('Whiteboard', 997),
+ ('Zucchini', 997),
+ ('Crocodile', 992),
+
+ ('Clock', 960),
+ ('Wall clock', 958),
+
+ ('Doughnut', 869),
+ ('Snail', 868),
+
+ ('Baseball glove', 859),
+
+ ('Panda', 830),
+ ('Tennis racket', 830),
+
+ ('Pear', 652),
+
+ ('Bagel', 617),
+ ('Oven', 616),
+ ('Ladybug', 615),
+ ('Shark', 615),
+ ('Polar bear', 614),
+ ('Ostrich', 609),
+
+ ('Hot dog', 473),
+ ('Microwave oven', 467),
+ ('Fire hydrant', 20),
+ ('Stop sign', 20),
+ ('Parking meter', 20),
+ ('Bear', 20),
+ ('Flying disc', 20),
+ ('Snowboard', 20),
+ ('Tennis ball', 20),
+ ('Kite', 20),
+ ('Baseball bat', 20),
+ ('Kitchen knife', 20),
+ ('Knife', 20),
+ ('Submarine sandwich', 20),
+ ('Computer mouse', 20),
+ ('Remote control', 20),
+ ('Toaster', 20),
+ ('Sink', 20),
+ ('Refrigerator', 20),
+ ('Alarm clock', 20),
+ ('Wall clock', 20),
+ ('Scissors', 20),
+ ('Hair dryer', 20),
+ ('Toothbrush', 20),
+ ('Suitcase', 20)
+]
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/data/sflckr.py b/src/custom_controlnet_aux/diffusion_edge/taming/data/sflckr.py
new file mode 100644
index 0000000000000000000000000000000000000000..91101be5953b113f1e58376af637e43f366b3dee
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/data/sflckr.py
@@ -0,0 +1,91 @@
+import os
+import numpy as np
+import cv2
+import albumentations
+from PIL import Image
+from torch.utils.data import Dataset
+
+
+class SegmentationBase(Dataset):
+ def __init__(self,
+ data_csv, data_root, segmentation_root,
+ size=None, random_crop=False, interpolation="bicubic",
+ n_labels=182, shift_segmentation=False,
+ ):
+ self.n_labels = n_labels
+ self.shift_segmentation = shift_segmentation
+ self.data_csv = data_csv
+ self.data_root = data_root
+ self.segmentation_root = segmentation_root
+ with open(self.data_csv, "r") as f:
+ self.image_paths = f.read().splitlines()
+ self._length = len(self.image_paths)
+ self.labels = {
+ "relative_file_path_": [l for l in self.image_paths],
+ "file_path_": [os.path.join(self.data_root, l)
+ for l in self.image_paths],
+ "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png"))
+ for l in self.image_paths]
+ }
+
+ size = None if size is not None and size<=0 else size
+ self.size = size
+ if self.size is not None:
+ self.interpolation = interpolation
+ self.interpolation = {
+ "nearest": cv2.INTER_NEAREST,
+ "bilinear": cv2.INTER_LINEAR,
+ "bicubic": cv2.INTER_CUBIC,
+ "area": cv2.INTER_AREA,
+ "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
+ self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
+ interpolation=self.interpolation)
+ self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
+ interpolation=cv2.INTER_NEAREST)
+ self.center_crop = not random_crop
+ if self.center_crop:
+ self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
+ else:
+ self.cropper = albumentations.RandomCrop(height=self.size, width=self.size)
+ self.preprocessor = self.cropper
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = dict((k, self.labels[k][i]) for k in self.labels)
+ image = Image.open(example["file_path_"])
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+ if self.size is not None:
+ image = self.image_rescaler(image=image)["image"]
+ segmentation = Image.open(example["segmentation_path_"])
+ assert segmentation.mode == "L", segmentation.mode
+ segmentation = np.array(segmentation).astype(np.uint8)
+ if self.shift_segmentation:
+ # used to support segmentations containing unlabeled==255 label
+ segmentation = segmentation+1
+ if self.size is not None:
+ segmentation = self.segmentation_rescaler(image=segmentation)["image"]
+ if self.size is not None:
+ processed = self.preprocessor(image=image,
+ mask=segmentation
+ )
+ else:
+ processed = {"image": image,
+ "mask": segmentation
+ }
+ example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
+ segmentation = processed["mask"]
+ onehot = np.eye(self.n_labels)[segmentation]
+ example["segmentation"] = onehot
+ return example
+
+
+class Examples(SegmentationBase):
+ def __init__(self, size=None, random_crop=False, interpolation="bicubic"):
+ super().__init__(data_csv="data/sflckr_examples.txt",
+ data_root="data/sflckr_images",
+ segmentation_root="data/sflckr_segmentations",
+ size=size, random_crop=random_crop, interpolation=interpolation)
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/data/utils.py b/src/custom_controlnet_aux/diffusion_edge/taming/data/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8be640a37f4fcba5d7bb863b8913ba829ee09a85
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/data/utils.py
@@ -0,0 +1,169 @@
+import collections
+import os
+import tarfile
+import urllib
+import zipfile
+from pathlib import Path
+
+import numpy as np
+import torch
+from custom_controlnet_aux.diffusion_edge.taming.data.helper_types import Annotation
+from torch._six import string_classes
+from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
+from tqdm import tqdm
+
+
+def unpack(path):
+ if path.endswith("tar.gz"):
+ with tarfile.open(path, "r:gz") as tar:
+ tar.extractall(path=os.path.split(path)[0])
+ elif path.endswith("tar"):
+ with tarfile.open(path, "r:") as tar:
+ tar.extractall(path=os.path.split(path)[0])
+ elif path.endswith("zip"):
+ with zipfile.ZipFile(path, "r") as f:
+ f.extractall(path=os.path.split(path)[0])
+ else:
+ raise NotImplementedError(
+ "Unknown file extension: {}".format(os.path.splitext(path)[1])
+ )
+
+
+def reporthook(bar):
+ """tqdm progress bar for downloads."""
+
+ def hook(b=1, bsize=1, tsize=None):
+ if tsize is not None:
+ bar.total = tsize
+ bar.update(b * bsize - bar.n)
+
+ return hook
+
+
+def get_root(name):
+ base = "data/"
+ root = os.path.join(base, name)
+ os.makedirs(root, exist_ok=True)
+ return root
+
+
+def is_prepared(root):
+ return Path(root).joinpath(".ready").exists()
+
+
+def mark_prepared(root):
+ Path(root).joinpath(".ready").touch()
+
+
+def prompt_download(file_, source, target_dir, content_dir=None):
+ targetpath = os.path.join(target_dir, file_)
+ while not os.path.exists(targetpath):
+ if content_dir is not None and os.path.exists(
+ os.path.join(target_dir, content_dir)
+ ):
+ break
+ print(
+ "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath)
+ )
+ if content_dir is not None:
+ print(
+ "Or place its content into '{}'.".format(
+ os.path.join(target_dir, content_dir)
+ )
+ )
+ input("Press Enter when done...")
+ return targetpath
+
+
+def download_url(file_, url, target_dir):
+ targetpath = os.path.join(target_dir, file_)
+ os.makedirs(target_dir, exist_ok=True)
+ with tqdm(
+ unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_
+ ) as bar:
+ urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar))
+ return targetpath
+
+
+def download_urls(urls, target_dir):
+ paths = dict()
+ for fname, url in urls.items():
+ outpath = download_url(fname, url, target_dir)
+ paths[fname] = outpath
+ return paths
+
+
+def quadratic_crop(x, bbox, alpha=1.0):
+ """bbox is xmin, ymin, xmax, ymax"""
+ im_h, im_w = x.shape[:2]
+ bbox = np.array(bbox, dtype=np.float32)
+ bbox = np.clip(bbox, 0, max(im_h, im_w))
+ center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3])
+ w = bbox[2] - bbox[0]
+ h = bbox[3] - bbox[1]
+ l = int(alpha * max(w, h))
+ l = max(l, 2)
+
+ required_padding = -1 * min(
+ center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l)
+ )
+ required_padding = int(np.ceil(required_padding))
+ if required_padding > 0:
+ padding = [
+ [required_padding, required_padding],
+ [required_padding, required_padding],
+ ]
+ padding += [[0, 0]] * (len(x.shape) - 2)
+ x = np.pad(x, padding, "reflect")
+ center = center[0] + required_padding, center[1] + required_padding
+ xmin = int(center[0] - l / 2)
+ ymin = int(center[1] - l / 2)
+ return np.array(x[ymin : ymin + l, xmin : xmin + l, ...])
+
+
+def custom_collate(batch):
+ r"""source: pytorch 1.9.0, only one modification to original code """
+
+ elem = batch[0]
+ elem_type = type(elem)
+ if isinstance(elem, torch.Tensor):
+ out = None
+ if torch.utils.data.get_worker_info() is not None:
+ # If we're in a background process, concatenate directly into a
+ # shared memory tensor to avoid an extra copy
+ numel = sum([x.numel() for x in batch])
+ storage = elem.storage()._new_shared(numel)
+ out = elem.new(storage)
+ return torch.stack(batch, 0, out=out)
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
+ and elem_type.__name__ != 'string_':
+ if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
+ # array of string classes and object
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
+ raise TypeError(default_collate_err_msg_format.format(elem.dtype))
+
+ return custom_collate([torch.as_tensor(b) for b in batch])
+ elif elem.shape == (): # scalars
+ return torch.as_tensor(batch)
+ elif isinstance(elem, float):
+ return torch.tensor(batch, dtype=torch.float64)
+ elif isinstance(elem, int):
+ return torch.tensor(batch)
+ elif isinstance(elem, string_classes):
+ return batch
+ elif isinstance(elem, collections.abc.Mapping):
+ return {key: custom_collate([d[key] for d in batch]) for key in elem}
+ elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
+ return elem_type(*(custom_collate(samples) for samples in zip(*batch)))
+ if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added
+ return batch # added
+ elif isinstance(elem, collections.abc.Sequence):
+ # check to make sure that the elements in batch have consistent size
+ it = iter(batch)
+ elem_size = len(next(it))
+ if not all(len(elem) == elem_size for elem in it):
+ raise RuntimeError('each element in list of batch should be of equal size')
+ transposed = zip(*batch)
+ return [custom_collate(samples) for samples in transposed]
+
+ raise TypeError(default_collate_err_msg_format.format(elem_type))
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/modules/autoencoder/lpips/vgg.pth b/src/custom_controlnet_aux/diffusion_edge/taming/modules/autoencoder/lpips/vgg.pth
new file mode 100644
index 0000000000000000000000000000000000000000..f57dcf5cc764d61c8a460365847fb2137ff0a62d
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/modules/autoencoder/lpips/vgg.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
+size 7289
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/modules/diffusionmodules/model.py b/src/custom_controlnet_aux/diffusion_edge/taming/modules/diffusionmodules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3a5db6aa2ef915e270f1ae135e4a9918fdd884c
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/modules/diffusionmodules/model.py
@@ -0,0 +1,776 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
+ From Fairseq.
+ Build sinusoidal embeddings.
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
+ return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0,1,0,1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+ dropout, temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels,
+ out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x+h
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = q.reshape(b,c,h*w)
+ q = q.permute(0,2,1) # b,hw,c
+ k = k.reshape(b,c,h*w) # b,c,hw
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b,c,h*w)
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b,c,h,w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+class Model(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, use_timestep=True):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = self.ch*4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch,
+ self.temb_ch),
+ torch.nn.Linear(self.temb_ch,
+ self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ skip_in = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch*in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+
+ def forward(self, x, t=None):
+ #assert x.shape[2] == x.shape[3] == self.resolution
+
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Encoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, double_z=True, **ignore_kwargs):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ 2*z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+
+ def forward(self, x):
+ #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
+
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, give_pre_end=False, **ignorekwargs):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,)+tuple(ch_mult)
+ block_in = ch*ch_mult[self.num_resolutions-1]
+ curr_res = resolution // 2**(self.num_resolutions-1)
+ self.z_shape = (1,z_channels,curr_res,curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, z):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class VUNet(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
+ in_channels, c_channels,
+ resolution, z_channels, use_timestep=False, **ignore_kwargs):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = self.ch*4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch,
+ self.temb_ch),
+ torch.nn.Linear(self.temb_ch,
+ self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(c_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ self.z_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ skip_in = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch*in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+
+ def forward(self, x, z):
+ #assert x.shape[2] == x.shape[3] == self.resolution
+
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ z = self.z_in(z)
+ h = torch.cat((h,z),dim=1)
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class SimpleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ super().__init__()
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
+ ResnetBlock(in_channels=in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=2 * in_channels,
+ out_channels=4 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=4 * in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ nn.Conv2d(2*in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True)])
+ # end
+ self.norm_out = Normalize(in_channels)
+ self.conv_out = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ for i, layer in enumerate(self.model):
+ if i in [1,2,3]:
+ x = layer(x, None)
+ else:
+ x = layer(x)
+
+ h = self.norm_out(x)
+ h = nonlinearity(h)
+ x = self.conv_out(h)
+ return x
+
+
+class UpsampleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+ ch_mult=(2,2), dropout=0.0):
+ super().__init__()
+ # upsampling
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ block_in = in_channels
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.res_blocks = nn.ModuleList()
+ self.upsample_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ res_block = []
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ res_block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ self.res_blocks.append(nn.ModuleList(res_block))
+ if i_level != self.num_resolutions - 1:
+ self.upsample_blocks.append(Upsample(block_in, True))
+ curr_res = curr_res * 2
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # upsampling
+ h = x
+ for k, i_level in enumerate(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.res_blocks[i_level][i_block](h, None)
+ if i_level != self.num_resolutions - 1:
+ h = self.upsample_blocks[k](h)
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/modules/discriminator/model.py b/src/custom_controlnet_aux/diffusion_edge/taming/modules/discriminator/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ab2cc89aa2d9478fcfd60c80b91a1f7aa829e4e
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/modules/discriminator/model.py
@@ -0,0 +1,131 @@
+import functools
+import torch.nn as nn
+
+
+from custom_controlnet_aux.diffusion_edge.taming.modules.util import ActNorm
+
+
+def weights_init(m):
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
+ elif classname.find('BatchNorm') != -1:
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
+ nn.init.constant_(m.bias.data, 0)
+
+
+class NLayerDiscriminator(nn.Module):
+ """Defines a PatchGAN discriminator as in Pix2Pix
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
+ """
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
+ """Construct a PatchGAN discriminator
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ ndf (int) -- the number of filters in the last conv layer
+ n_layers (int) -- the number of conv layers in the discriminator
+ norm_layer -- normalization layer
+ """
+ super(NLayerDiscriminator, self).__init__()
+ if not use_actnorm:
+ norm_layer = nn.BatchNorm2d
+ else:
+ norm_layer = ActNorm
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
+ use_bias = norm_layer.func != nn.BatchNorm2d
+ else:
+ use_bias = norm_layer != nn.BatchNorm2d
+
+ kw = 4
+ padw = 1
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
+ nf_mult = 1
+ nf_mult_prev = 1
+ for n in range(1, n_layers): # gradually increase the number of filters
+ nf_mult_prev = nf_mult
+ nf_mult = min(2 ** n, 8)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ nf_mult_prev = nf_mult
+ nf_mult = min(2 ** n_layers, 8)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ sequence += [
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
+ self.main = nn.Sequential(*sequence)
+
+ def forward(self, input):
+ """Standard forward."""
+ return self.main(input)
+
+class NLayerDiscriminator2(nn.Module):
+ """Defines a PatchGAN discriminator as in Pix2Pix
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
+ """
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
+ """Construct a PatchGAN discriminator
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ ndf (int) -- the number of filters in the last conv layer
+ n_layers (int) -- the number of conv layers in the discriminator
+ norm_layer -- normalization layer
+ """
+ super(NLayerDiscriminator2, self).__init__()
+ if not use_actnorm:
+ norm_layer = nn.BatchNorm3d
+ else:
+ norm_layer = ActNorm
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
+ use_bias = norm_layer.func != nn.BatchNorm3d
+ else:
+ use_bias = norm_layer != nn.BatchNorm3d
+
+ kw = 4
+ padw = 1
+ sequence = [nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
+ nf_mult = 1
+ nf_mult_prev = 1
+ for n in range(1, n_layers): # gradually increase the number of filters
+ nf_mult_prev = nf_mult
+ nf_mult = min(2 ** n, 8)
+ sequence += [
+ nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2,
+ padding=padw, bias=use_bias, groups=8),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ nf_mult_prev = nf_mult
+ nf_mult = min(2 ** n_layers, 8)
+ sequence += [
+ nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1,
+ padding=padw, bias=use_bias, groups=8),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ sequence += [
+ nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw),
+ # nn.Sigmoid()
+ ] # output 1 channel prediction map
+ self.main = nn.Sequential(*sequence)
+
+ def forward(self, input):
+ """Standard forward."""
+ return self.main(input)
+
+if __name__ == "__main__":
+ import torch
+ model = NLayerDiscriminator2(input_nc=3, ndf=64, n_layers=3)
+ x = torch.rand(1, 3, 64, 64, 64)
+ with torch.no_grad():
+ y = model(x)
+ pause = 0
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/modules/losses/__init__.py b/src/custom_controlnet_aux/diffusion_edge/taming/modules/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3d03a59762ab71d59bb7c3f8977ee0fd459d0bc
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/modules/losses/__init__.py
@@ -0,0 +1,2 @@
+from custom_controlnet_aux.diffusion_edge.taming.modules.losses.vqperceptual import DummyLoss
+
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/modules/losses/lpips.py b/src/custom_controlnet_aux/diffusion_edge/taming/modules/losses/lpips.py
new file mode 100644
index 0000000000000000000000000000000000000000..32666db18cb2e89e65b8e905028aaaed81ccb017
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/modules/losses/lpips.py
@@ -0,0 +1,126 @@
+"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
+
+import torch
+import torch.nn as nn
+from torchvision import models
+from collections import namedtuple
+
+from .util import get_ckpt_path
+
+from custom_controlnet_aux.util import custom_torch_download
+
+class LPIPS(nn.Module):
+ # Learned perceptual metric
+ def __init__(self, use_dropout=True):
+ super().__init__()
+ self.scaling_layer = ScalingLayer()
+ self.chns = [64, 128, 256, 512, 512] # vg16 features
+ self.net = vgg16(pretrained=False, requires_grad=False)
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
+ self.load_from_pretrained()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def load_from_pretrained(self, name="vgg_lpips"):
+ ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
+ self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
+ print("loaded pretrained LPIPS loss from {}".format(ckpt))
+
+ @classmethod
+ def from_pretrained(cls, name="vgg_lpips"):
+ if name != "vgg_lpips":
+ raise NotImplementedError
+ model = cls()
+ ckpt = get_ckpt_path(name)
+ model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
+ return model
+
+ def forward(self, input, target):
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
+ feats0, feats1, diffs = {}, {}, {}
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
+ for kk in range(len(self.chns)):
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
+
+ res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
+ val = res[0]
+ for l in range(1, len(self.chns)):
+ val += res[l]
+ return val
+
+
+class ScalingLayer(nn.Module):
+ def __init__(self):
+ super(ScalingLayer, self).__init__()
+ self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
+ self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
+
+ def forward(self, inp):
+ return (inp - self.shift) / self.scale
+
+
+class NetLinLayer(nn.Module):
+ """ A single linear layer which does a 1x1 conv """
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
+ super(NetLinLayer, self).__init__()
+ layers = [nn.Dropout(), ] if (use_dropout) else []
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
+ self.model = nn.Sequential(*layers)
+
+
+class vgg16(torch.nn.Module):
+ def __init__(self, requires_grad=False, pretrained=False):
+ super(vgg16, self).__init__()
+ vgg16_model = models.vgg16(pretrained=pretrained)
+ vgg16_model.load_state_dict(torch.load(custom_torch_download(filename="vgg16-397923af.pth")), strict=True)
+ vgg_pretrained_features = vgg16_model.features
+ self.slice1 = torch.nn.Sequential()
+ self.slice2 = torch.nn.Sequential()
+ self.slice3 = torch.nn.Sequential()
+ self.slice4 = torch.nn.Sequential()
+ self.slice5 = torch.nn.Sequential()
+ self.N_slices = 5
+ for x in range(4):
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(4, 9):
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(9, 16):
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(16, 23):
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(23, 30):
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
+ if not requires_grad:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, X):
+ h = self.slice1(X)
+ h_relu1_2 = h
+ h = self.slice2(h)
+ h_relu2_2 = h
+ h = self.slice3(h)
+ h_relu3_3 = h
+ h = self.slice4(h)
+ h_relu4_3 = h
+ h = self.slice5(h)
+ h_relu5_3 = h
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
+ return out
+
+
+def normalize_tensor(x,eps=1e-10):
+ norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
+ return x/(norm_factor+eps)
+
+
+def spatial_average(x, keepdim=True):
+ return x.mean([2,3],keepdim=keepdim)
+
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/modules/losses/segmentation.py b/src/custom_controlnet_aux/diffusion_edge/taming/modules/losses/segmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ba77deb5159a6307ed2acba9945e4764a4ff0a5
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/modules/losses/segmentation.py
@@ -0,0 +1,22 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BCELoss(nn.Module):
+ def forward(self, prediction, target):
+ loss = F.binary_cross_entropy_with_logits(prediction,target)
+ return loss, {}
+
+
+class BCELossWithQuant(nn.Module):
+ def __init__(self, codebook_weight=1.):
+ super().__init__()
+ self.codebook_weight = codebook_weight
+
+ def forward(self, qloss, target, prediction, split):
+ bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
+ loss = bce_loss + self.codebook_weight*qloss
+ return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
+ "{}/bce_loss".format(split): bce_loss.detach().mean(),
+ "{}/quant_loss".format(split): qloss.detach().mean()
+ }
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/modules/losses/util.py b/src/custom_controlnet_aux/diffusion_edge/taming/modules/losses/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..06053e5defb87977f9ab07e69bf4da12201de9b7
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/modules/losses/util.py
@@ -0,0 +1,157 @@
+import os, hashlib
+import requests
+from tqdm import tqdm
+
+URL_MAP = {
+ "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
+}
+
+CKPT_MAP = {
+ "vgg_lpips": "vgg.pth"
+}
+
+MD5_MAP = {
+ "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
+}
+
+
+def download(url, local_path, chunk_size=1024):
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
+ with requests.get(url, stream=True) as r:
+ total_size = int(r.headers.get("content-length", 0))
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
+ with open(local_path, "wb") as f:
+ for data in r.iter_content(chunk_size=chunk_size):
+ if data:
+ f.write(data)
+ pbar.update(chunk_size)
+
+
+def md5_hash(path):
+ with open(path, "rb") as f:
+ content = f.read()
+ return hashlib.md5(content).hexdigest()
+
+
+def get_ckpt_path(name, root, check=False):
+ assert name in URL_MAP
+ path = os.path.join(root, CKPT_MAP[name])
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
+ download(URL_MAP[name], path)
+ md5 = md5_hash(path)
+ assert md5 == MD5_MAP[name], md5
+ return path
+
+
+class KeyNotFoundError(Exception):
+ def __init__(self, cause, keys=None, visited=None):
+ self.cause = cause
+ self.keys = keys
+ self.visited = visited
+ messages = list()
+ if keys is not None:
+ messages.append("Key not found: {}".format(keys))
+ if visited is not None:
+ messages.append("Visited: {}".format(visited))
+ messages.append("Cause:\n{}".format(cause))
+ message = "\n".join(messages)
+ super().__init__(message)
+
+
+def retrieve(
+ list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
+):
+ """Given a nested list or dict return the desired value at key expanding
+ callable nodes if necessary and :attr:`expand` is ``True``. The expansion
+ is done in-place.
+
+ Parameters
+ ----------
+ list_or_dict : list or dict
+ Possibly nested list or dictionary.
+ key : str
+ key/to/value, path like string describing all keys necessary to
+ consider to get to the desired value. List indices can also be
+ passed here.
+ splitval : str
+ String that defines the delimiter between keys of the
+ different depth levels in `key`.
+ default : obj
+ Value returned if :attr:`key` is not found.
+ expand : bool
+ Whether to expand callable nodes on the path or not.
+
+ Returns
+ -------
+ The desired value or if :attr:`default` is not ``None`` and the
+ :attr:`key` is not found returns ``default``.
+
+ Raises
+ ------
+ Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
+ ``None``.
+ """
+
+ keys = key.split(splitval)
+
+ success = True
+ try:
+ visited = []
+ parent = None
+ last_key = None
+ for key in keys:
+ if callable(list_or_dict):
+ if not expand:
+ raise KeyNotFoundError(
+ ValueError(
+ "Trying to get past callable node with expand=False."
+ ),
+ keys=keys,
+ visited=visited,
+ )
+ list_or_dict = list_or_dict()
+ parent[last_key] = list_or_dict
+
+ last_key = key
+ parent = list_or_dict
+
+ try:
+ if isinstance(list_or_dict, dict):
+ list_or_dict = list_or_dict[key]
+ else:
+ list_or_dict = list_or_dict[int(key)]
+ except (KeyError, IndexError, ValueError) as e:
+ raise KeyNotFoundError(e, keys=keys, visited=visited)
+
+ visited += [key]
+ # final expansion of retrieved value
+ if expand and callable(list_or_dict):
+ list_or_dict = list_or_dict()
+ parent[last_key] = list_or_dict
+ except KeyNotFoundError as e:
+ if default is None:
+ raise e
+ else:
+ list_or_dict = default
+ success = False
+
+ if not pass_success:
+ return list_or_dict
+ else:
+ return list_or_dict, success
+
+
+if __name__ == "__main__":
+ config = {"keya": "a",
+ "keyb": "b",
+ "keyc":
+ {"cc1": 1,
+ "cc2": 2,
+ }
+ }
+ from omegaconf import OmegaConf
+ config = OmegaConf.create(config)
+ print(config)
+ retrieve(config, "keya")
+
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/modules/losses/vqperceptual.py b/src/custom_controlnet_aux/diffusion_edge/taming/modules/losses/vqperceptual.py
new file mode 100644
index 0000000000000000000000000000000000000000..67f01102c77e9793197059726c0165afb8d81ace
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/modules/losses/vqperceptual.py
@@ -0,0 +1,136 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from custom_controlnet_aux.diffusion_edge.taming.modules.losses.lpips import LPIPS
+from custom_controlnet_aux.diffusion_edge.taming.modules.discriminator.model import NLayerDiscriminator, weights_init, NLayerDiscriminator2
+
+
+class DummyLoss(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+
+def adopt_weight(weight, global_step, threshold=0, value=0.):
+ if global_step < threshold:
+ weight = value
+ return weight
+
+
+def hinge_d_loss(logits_real, logits_fake):
+ loss_real = torch.mean(F.relu(1. - logits_real))
+ loss_fake = torch.mean(F.relu(1. + logits_fake))
+ d_loss = 0.5 * (loss_real + loss_fake)
+ return d_loss
+
+
+def vanilla_d_loss(logits_real, logits_fake):
+ d_loss = 0.5 * (
+ torch.mean(torch.nn.functional.softplus(-logits_real)) +
+ torch.mean(torch.nn.functional.softplus(logits_fake)))
+ return d_loss
+
+
+class VQLPIPSWithDiscriminator(nn.Module):
+ def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+ disc_ndf=64, disc_loss="hinge"):
+ super().__init__()
+ assert disc_loss in ["hinge", "vanilla"]
+ self.codebook_weight = codebook_weight
+ self.pixel_weight = pixelloss_weight
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
+ n_layers=disc_num_layers,
+ use_actnorm=use_actnorm,
+ ndf=disc_ndf
+ ).apply(weights_init)
+ self.discriminator_iter_start = disc_start
+ if disc_loss == "hinge":
+ self.disc_loss = hinge_d_loss
+ elif disc_loss == "vanilla":
+ self.disc_loss = vanilla_d_loss
+ else:
+ raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
+ print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.disc_conditional = disc_conditional
+
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+ if last_layer is not None:
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+ else:
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
+
+ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
+ global_step, last_layer=None, cond=None, split="train"):
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+ if self.perceptual_weight > 0:
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+ else:
+ p_loss = torch.tensor([0.0])
+
+ nll_loss = rec_loss
+ #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+ nll_loss = torch.mean(nll_loss)
+
+ # now the GAN part
+ if optimizer_idx == 0:
+ # generator update
+ if cond is None:
+ assert not self.disc_conditional
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ else:
+ assert self.disc_conditional
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
+ g_loss = -torch.mean(logits_fake)
+
+ try:
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
+ except RuntimeError:
+ assert not self.training
+ d_weight = torch.tensor(0.0)
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
+
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
+ "{}/quant_loss".format(split): codebook_loss.detach().mean(),
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
+ "{}/p_loss".format(split): p_loss.detach().mean(),
+ "{}/d_weight".format(split): d_weight.detach(),
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
+ "{}/g_loss".format(split): g_loss.detach().mean(),
+ }
+ return loss, log
+
+ if optimizer_idx == 1:
+ # second pass for discriminator update
+ if cond is None:
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+ else:
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+ log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+ "{}/logits_real".format(split): logits_real.detach().mean(),
+ "{}/logits_fake".format(split): logits_fake.detach().mean()
+ }
+ return d_loss, log
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/modules/misc/coord.py b/src/custom_controlnet_aux/diffusion_edge/taming/modules/misc/coord.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee69b0c897b6b382ae673622e420f55e494f5b09
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/modules/misc/coord.py
@@ -0,0 +1,31 @@
+import torch
+
+class CoordStage(object):
+ def __init__(self, n_embed, down_factor):
+ self.n_embed = n_embed
+ self.down_factor = down_factor
+
+ def eval(self):
+ return self
+
+ def encode(self, c):
+ """fake vqmodel interface"""
+ assert 0.0 <= c.min() and c.max() <= 1.0
+ b,ch,h,w = c.shape
+ assert ch == 1
+
+ c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
+ mode="area")
+ c = c.clamp(0.0, 1.0)
+ c = self.n_embed*c
+ c_quant = c.round()
+ c_ind = c_quant.to(dtype=torch.long)
+
+ info = None, None, c_ind
+ return c_quant, None, info
+
+ def decode(self, c):
+ c = c/self.n_embed
+ c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
+ mode="nearest")
+ return c
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/modules/util.py b/src/custom_controlnet_aux/diffusion_edge/taming/modules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ee16385d8b1342a2d60a5f1aa5cadcfbe934bd8
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/modules/util.py
@@ -0,0 +1,130 @@
+import torch
+import torch.nn as nn
+
+
+def count_params(model):
+ total_params = sum(p.numel() for p in model.parameters())
+ return total_params
+
+
+class ActNorm(nn.Module):
+ def __init__(self, num_features, logdet=False, affine=True,
+ allow_reverse_init=False):
+ assert affine
+ super().__init__()
+ self.logdet = logdet
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
+ self.allow_reverse_init = allow_reverse_init
+
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
+
+ def initialize(self, input):
+ with torch.no_grad():
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
+ mean = (
+ flatten.mean(1)
+ .unsqueeze(1)
+ .unsqueeze(2)
+ .unsqueeze(3)
+ .permute(1, 0, 2, 3)
+ )
+ std = (
+ flatten.std(1)
+ .unsqueeze(1)
+ .unsqueeze(2)
+ .unsqueeze(3)
+ .permute(1, 0, 2, 3)
+ )
+
+ self.loc.data.copy_(-mean)
+ self.scale.data.copy_(1 / (std + 1e-6))
+
+ def forward(self, input, reverse=False):
+ if reverse:
+ return self.reverse(input)
+ if len(input.shape) == 2:
+ input = input[:,:,None,None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ _, _, height, width = input.shape
+
+ if self.training and self.initialized.item() == 0:
+ self.initialize(input)
+ self.initialized.fill_(1)
+
+ h = self.scale * (input + self.loc)
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+
+ if self.logdet:
+ log_abs = torch.log(torch.abs(self.scale))
+ logdet = height*width*torch.sum(log_abs)
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
+ return h, logdet
+
+ return h
+
+ def reverse(self, output):
+ if self.training and self.initialized.item() == 0:
+ if not self.allow_reverse_init:
+ raise RuntimeError(
+ "Initializing ActNorm in reverse direction is "
+ "disabled by default. Use allow_reverse_init=True to enable."
+ )
+ else:
+ self.initialize(output)
+ self.initialized.fill_(1)
+
+ if len(output.shape) == 2:
+ output = output[:,:,None,None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ h = output / self.scale - self.loc
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+ return h
+
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+class Labelator(AbstractEncoder):
+ """Net2Net Interface for Class-Conditional Model"""
+ def __init__(self, n_classes, quantize_interface=True):
+ super().__init__()
+ self.n_classes = n_classes
+ self.quantize_interface = quantize_interface
+
+ def encode(self, c):
+ c = c[:,None]
+ if self.quantize_interface:
+ return c, None, [None, None, c.long()]
+ return c
+
+
+class SOSProvider(AbstractEncoder):
+ # for unconditional training
+ def __init__(self, sos_token, quantize_interface=True):
+ super().__init__()
+ self.sos_token = sos_token
+ self.quantize_interface = quantize_interface
+
+ def encode(self, x):
+ # get batch size from data and replicate sos_token
+ c = torch.ones(x.shape[0], 1)*self.sos_token
+ c = c.long().to(x.device)
+ if self.quantize_interface:
+ return c, None, [None, None, c]
+ return c
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/modules/vqvae/quantize.py b/src/custom_controlnet_aux/diffusion_edge/taming/modules/vqvae/quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..d75544e41fa01bce49dd822b1037963d62f79b51
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/modules/vqvae/quantize.py
@@ -0,0 +1,445 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from torch import einsum
+from einops import rearrange
+
+
+class VectorQuantizer(nn.Module):
+ """
+ see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
+ ____________________________________________
+ Discretization bottleneck part of the VQ-VAE.
+ Inputs:
+ - n_e : number of embeddings
+ - e_dim : dimension of embedding
+ - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
+ _____________________________________________
+ """
+
+ # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
+ # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
+ # used wherever VectorQuantizer has been used before and is additionally
+ # more efficient.
+ def __init__(self, n_e, e_dim, beta):
+ super(VectorQuantizer, self).__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ def forward(self, z):
+ """
+ Inputs the output of the encoder network z and maps it to a discrete
+ one-hot vector that is the index of the closest embedding vector e_j
+ z (continuous) -> z_q (discrete)
+ z.shape = (batch, channel, height, width)
+ quantization pipeline:
+ 1. get encoder input (B,C,H,W)
+ 2. flatten input to (B*H*W,C)
+ """
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = z.permute(0, 2, 3, 1).contiguous()
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
+ torch.matmul(z_flattened, self.embedding.weight.t())
+
+ ## could possible replace this here
+ # #\start...
+ # find closest encodings
+ min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
+
+ min_encodings = torch.zeros(
+ min_encoding_indices.shape[0], self.n_e).to(z)
+ min_encodings.scatter_(1, min_encoding_indices, 1)
+
+ # dtype min encodings: torch.float32
+ # min_encodings shape: torch.Size([2048, 512])
+ # min_encoding_indices.shape: torch.Size([2048, 1])
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
+ #.........\end
+
+ # with:
+ # .........\start
+ #min_encoding_indices = torch.argmin(d, dim=1)
+ #z_q = self.embedding(min_encoding_indices)
+ # ......\end......... (TODO)
+
+ # compute loss for embedding
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
+ torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # perplexity
+ e_mean = torch.mean(min_encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
+
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape):
+ # shape specifying (batch, height, width, channel)
+ # TODO: check for more easy handling with nn.Embedding
+ min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
+ min_encodings.scatter_(1, indices[:,None], 1)
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class GumbelQuantize(nn.Module):
+ """
+ credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
+ Gumbel Softmax trick quantizer
+ Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
+ https://arxiv.org/abs/1611.01144
+ """
+ def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
+ kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
+ remap=None, unknown_index="random"):
+ super().__init__()
+
+ self.embedding_dim = embedding_dim
+ self.n_embed = n_embed
+
+ self.straight_through = straight_through
+ self.temperature = temp_init
+ self.kl_weight = kl_weight
+
+ self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
+ self.embed = nn.Embedding(n_embed, embedding_dim)
+
+ self.use_vqinterface = use_vqinterface
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed+1
+ print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_embed
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ match = (inds[:,:,None]==used[None,None,...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2)<1
+ if self.unknown_index == "random":
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z, temp=None, return_logits=False):
+ # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
+ hard = self.straight_through if self.training else True
+ temp = self.temperature if temp is None else temp
+
+ logits = self.proj(z)
+ if self.remap is not None:
+ # continue only with used logits
+ full_zeros = torch.zeros_like(logits)
+ logits = logits[:,self.used,...]
+
+ soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
+ if self.remap is not None:
+ # go back to all entries but unused set to zero
+ full_zeros[:,self.used,...] = soft_one_hot
+ soft_one_hot = full_zeros
+ z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
+
+ # + kl divergence to the prior loss
+ qy = F.softmax(logits, dim=1)
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
+
+ ind = soft_one_hot.argmax(dim=1)
+ if self.remap is not None:
+ ind = self.remap_to_used(ind)
+ if self.use_vqinterface:
+ if return_logits:
+ return z_q, diff, (None, None, ind), logits
+ return z_q, diff, (None, None, ind)
+ return z_q, diff, ind
+
+ def get_codebook_entry(self, indices, shape):
+ b, h, w, c = shape
+ assert b*h*w == indices.shape[0]
+ indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
+ if self.remap is not None:
+ indices = self.unmap_to_all(indices)
+ one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
+ z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
+ return z_q
+
+
+class VectorQuantizer2(nn.Module):
+ """
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
+ """
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
+ # backwards compatibility we use the buggy version by default, but you can
+ # specify legacy=False to fix it.
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
+ sane_index_shape=False, legacy=True):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+ self.legacy = legacy
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed+1
+ print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_e
+
+ self.sane_index_shape = sane_index_shape
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ match = (inds[:,:,None]==used[None,None,...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2)<1
+ if self.unknown_index == "random":
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
+ assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
+ assert rescale_logits==False, "Only for interface compatible with Gumbel"
+ assert return_logits==False, "Only for interface compatible with Gumbel"
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = rearrange(z, 'b c h w -> b h w c').contiguous()
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
+ perplexity = None
+ min_encodings = None
+
+ # compute loss for embedding
+ if not self.legacy:
+ loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
+ torch.mean((z_q - z.detach()) ** 2)
+ else:
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
+ torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
+
+ if self.remap is not None:
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
+ min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten
+
+ if self.sane_index_shape:
+ min_encoding_indices = min_encoding_indices.reshape(
+ z_q.shape[0], z_q.shape[2], z_q.shape[3])
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape):
+ # shape specifying (batch, height, width, channel)
+ if self.remap is not None:
+ indices = indices.reshape(shape[0],-1) # add batch axis
+ indices = self.unmap_to_all(indices)
+ indices = indices.reshape(-1) # flatten again
+
+ # get quantized latent vectors
+ z_q = self.embedding(indices)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+class EmbeddingEMA(nn.Module):
+ def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
+ super().__init__()
+ self.decay = decay
+ self.eps = eps
+ weight = torch.randn(num_tokens, codebook_dim)
+ self.weight = nn.Parameter(weight, requires_grad = False)
+ self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False)
+ self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False)
+ self.update = True
+
+ def forward(self, embed_id):
+ return F.embedding(embed_id, self.weight)
+
+ def cluster_size_ema_update(self, new_cluster_size):
+ self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
+
+ def embed_avg_ema_update(self, new_embed_avg):
+ self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
+
+ def weight_update(self, num_tokens):
+ n = self.cluster_size.sum()
+ smoothed_cluster_size = (
+ (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
+ )
+ #normalize embedding average with smoothed cluster size
+ embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
+ self.weight.data.copy_(embed_normalized)
+
+
+class EMAVectorQuantizer(nn.Module):
+ def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
+ remap=None, unknown_index="random"):
+ super().__init__()
+ self.codebook_dim = codebook_dim
+ self.num_tokens = num_tokens
+ self.beta = beta
+ self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed+1
+ print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_embed
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ match = (inds[:,:,None]==used[None,None,...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2)<1
+ if self.unknown_index == "random":
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z):
+ # reshape z -> (batch, height, width, channel) and flatten
+ #z, 'b c h w -> b h w c'
+ z = rearrange(z, 'b c h w -> b h w c')
+ z_flattened = z.reshape(-1, self.codebook_dim)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
+ self.embedding.weight.pow(2).sum(dim=1) - 2 * \
+ torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
+
+
+ encoding_indices = torch.argmin(d, dim=1)
+
+ z_q = self.embedding(encoding_indices).view(z.shape)
+ encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
+ avg_probs = torch.mean(encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+
+ if self.training and self.embedding.update:
+ #EMA cluster size
+ encodings_sum = encodings.sum(0)
+ self.embedding.cluster_size_ema_update(encodings_sum)
+ #EMA embedding average
+ embed_sum = encodings.transpose(0,1) @ z_flattened
+ self.embedding.embed_avg_ema_update(embed_sum)
+ #normalize embed_avg and update weight
+ self.embedding.weight_update(self.num_tokens)
+
+ # compute loss for embedding
+ loss = self.beta * F.mse_loss(z_q.detach(), z)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ #z_q, 'b h w c -> b c h w'
+ z_q = rearrange(z_q, 'b h w c -> b c h w')
+ return z_q, loss, (perplexity, encodings, encoding_indices)
diff --git a/src/custom_controlnet_aux/diffusion_edge/taming/util.py b/src/custom_controlnet_aux/diffusion_edge/taming/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..06053e5defb87977f9ab07e69bf4da12201de9b7
--- /dev/null
+++ b/src/custom_controlnet_aux/diffusion_edge/taming/util.py
@@ -0,0 +1,157 @@
+import os, hashlib
+import requests
+from tqdm import tqdm
+
+URL_MAP = {
+ "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
+}
+
+CKPT_MAP = {
+ "vgg_lpips": "vgg.pth"
+}
+
+MD5_MAP = {
+ "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
+}
+
+
+def download(url, local_path, chunk_size=1024):
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
+ with requests.get(url, stream=True) as r:
+ total_size = int(r.headers.get("content-length", 0))
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
+ with open(local_path, "wb") as f:
+ for data in r.iter_content(chunk_size=chunk_size):
+ if data:
+ f.write(data)
+ pbar.update(chunk_size)
+
+
+def md5_hash(path):
+ with open(path, "rb") as f:
+ content = f.read()
+ return hashlib.md5(content).hexdigest()
+
+
+def get_ckpt_path(name, root, check=False):
+ assert name in URL_MAP
+ path = os.path.join(root, CKPT_MAP[name])
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
+ download(URL_MAP[name], path)
+ md5 = md5_hash(path)
+ assert md5 == MD5_MAP[name], md5
+ return path
+
+
+class KeyNotFoundError(Exception):
+ def __init__(self, cause, keys=None, visited=None):
+ self.cause = cause
+ self.keys = keys
+ self.visited = visited
+ messages = list()
+ if keys is not None:
+ messages.append("Key not found: {}".format(keys))
+ if visited is not None:
+ messages.append("Visited: {}".format(visited))
+ messages.append("Cause:\n{}".format(cause))
+ message = "\n".join(messages)
+ super().__init__(message)
+
+
+def retrieve(
+ list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
+):
+ """Given a nested list or dict return the desired value at key expanding
+ callable nodes if necessary and :attr:`expand` is ``True``. The expansion
+ is done in-place.
+
+ Parameters
+ ----------
+ list_or_dict : list or dict
+ Possibly nested list or dictionary.
+ key : str
+ key/to/value, path like string describing all keys necessary to
+ consider to get to the desired value. List indices can also be
+ passed here.
+ splitval : str
+ String that defines the delimiter between keys of the
+ different depth levels in `key`.
+ default : obj
+ Value returned if :attr:`key` is not found.
+ expand : bool
+ Whether to expand callable nodes on the path or not.
+
+ Returns
+ -------
+ The desired value or if :attr:`default` is not ``None`` and the
+ :attr:`key` is not found returns ``default``.
+
+ Raises
+ ------
+ Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
+ ``None``.
+ """
+
+ keys = key.split(splitval)
+
+ success = True
+ try:
+ visited = []
+ parent = None
+ last_key = None
+ for key in keys:
+ if callable(list_or_dict):
+ if not expand:
+ raise KeyNotFoundError(
+ ValueError(
+ "Trying to get past callable node with expand=False."
+ ),
+ keys=keys,
+ visited=visited,
+ )
+ list_or_dict = list_or_dict()
+ parent[last_key] = list_or_dict
+
+ last_key = key
+ parent = list_or_dict
+
+ try:
+ if isinstance(list_or_dict, dict):
+ list_or_dict = list_or_dict[key]
+ else:
+ list_or_dict = list_or_dict[int(key)]
+ except (KeyError, IndexError, ValueError) as e:
+ raise KeyNotFoundError(e, keys=keys, visited=visited)
+
+ visited += [key]
+ # final expansion of retrieved value
+ if expand and callable(list_or_dict):
+ list_or_dict = list_or_dict()
+ parent[last_key] = list_or_dict
+ except KeyNotFoundError as e:
+ if default is None:
+ raise e
+ else:
+ list_or_dict = default
+ success = False
+
+ if not pass_success:
+ return list_or_dict
+ else:
+ return list_or_dict, success
+
+
+if __name__ == "__main__":
+ config = {"keya": "a",
+ "keyb": "b",
+ "keyc":
+ {"cc1": 1,
+ "cc2": 2,
+ }
+ }
+ from omegaconf import OmegaConf
+ config = OmegaConf.create(config)
+ print(config)
+ retrieve(config, "keya")
+
diff --git a/src/custom_controlnet_aux/dsine/LICENSE b/src/custom_controlnet_aux/dsine/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..2b677c1a1ec5f491a189df741896bc7207093c8f
--- /dev/null
+++ b/src/custom_controlnet_aux/dsine/LICENSE
@@ -0,0 +1,230 @@
+DSINE SOFTWARE
+
+LICENCE AGREEMENT
+
+WE (Imperial College of Science, Technology and Medicine, (“Imperial College
+London”)) ARE WILLING TO LICENSE THIS SOFTWARE TO YOU (a licensee “You”) ONLY
+ON THE CONDITION THAT YOU ACCEPT ALL OF THE TERMS CONTAINED IN THE FOLLOWING
+AGREEMENT. PLEASE READ THE AGREEMENT CAREFULLY BEFORE DOWNLOADING THE SOFTWARE.
+BY EXERCISING THE OPTION TO DOWNLOAD THE SOFTWARE YOU AGREE TO BE BOUND BY THE
+TERMS OF THE AGREEMENT.
+
+SOFTWARE LICENCE AGREEMENT (EXCLUDING BSD COMPONENTS)
+
+1. This Agreement pertains to a worldwide, non-exclusive, temporary, fully
+paid-up, royalty free, non-transferable, non-sub- licensable licence (the
+“Licence”) to use the elastic fusion source code, including any modification,
+part or derivative (the “Software”).
+
+Ownership and Licence. Your rights to use and download the Software onto your
+computer, and all other copies that You are authorised to make, are specified
+in this Agreement. However, we (or our licensors) retain all rights, including
+but not limited to all copyright and other intellectual property rights
+anywhere in the world, in the Software not expressly granted to You in this
+Agreement.
+
+2. Permitted use of the Licence:
+
+(a) You may download and install the Software onto one computer or server for
+use in accordance with Clause 2(b) of this Agreement provided that You ensure
+that the Software is not accessible by other users unless they have themselves
+accepted the terms of this licence agreement.
+
+(b) You may use the Software solely for non-commercial, internal or academic
+research purposes and only in accordance with the terms of this Agreement. You
+may not use the Software for commercial purposes, including but not limited to
+(1) integration of all or part of the source code or the Software into a
+product for sale or licence by or on behalf of You to third parties or (2) use
+of the Software or any derivative of it for research to develop software
+products for sale or licence to a third party or (3) use of the Software or any
+derivative of it for research to develop non-software products for sale or
+licence to a third party, or (4) use of the Software to provide any service to
+an external organisation for which payment is received.
+
+Should You wish to use the Software for commercial purposes, You shall
+email researchcontracts.engineering@imperial.ac.uk .
+
+(c) Right to Copy. You may copy the Software for back-up and archival purposes,
+provided that each copy is kept in your possession and provided You reproduce
+our copyright notice (set out in Schedule 1) on each copy.
+
+(d) Transfer and sub-licensing. You may not rent, lend, or lease the Software
+and You may not transmit, transfer or sub-license this licence to use the
+Software or any of your rights or obligations under this Agreement to another
+party.
+
+(e) Identity of Licensee. The licence granted herein is personal to You. You
+shall not permit any third party to access, modify or otherwise use the
+Software nor shall You access modify or otherwise use the Software on behalf of
+any third party. If You wish to obtain a licence for mutiple users or a site
+licence for the Software please contact us
+at researchcontracts.engineering@imperial.ac.uk .
+
+(f) Publications and presentations. You may make public, results or data
+obtained from, dependent on or arising from research carried out using the
+Software, provided that any such presentation or publication identifies the
+Software as the source of the results or the data, including the Copyright
+Notice given in each element of the Software, and stating that the Software has
+been made available for use by You under licence from Imperial College London
+and You provide a copy of any such publication to Imperial College London.
+
+3. Prohibited Uses. You may not, without written permission from us
+at researchcontracts.engineering@imperial.ac.uk :
+
+(a) Use, copy, modify, merge, or transfer copies of the Software or any
+documentation provided by us which relates to the Software except as provided
+in this Agreement;
+
+(b) Use any back-up or archival copies of the Software (or allow anyone else to
+use such copies) for any purpose other than to replace the original copy in the
+event it is destroyed or becomes defective; or
+
+(c) Disassemble, decompile or "unlock", reverse translate, or in any manner
+decode the Software for any reason.
+
+4. Warranty Disclaimer
+
+(a) Disclaimer. The Software has been developed for research purposes only. You
+acknowledge that we are providing the Software to You under this licence
+agreement free of charge and on condition that the disclaimer set out below
+shall apply. We do not represent or warrant that the Software as to: (i) the
+quality, accuracy or reliability of the Software; (ii) the suitability of the
+Software for any particular use or for use under any specific conditions; and
+(iii) whether use of the Software will infringe third-party rights.
+
+You acknowledge that You have reviewed and evaluated the Software to determine
+that it meets your needs and that You assume all responsibility and liability
+for determining the suitability of the Software as fit for your particular
+purposes and requirements. Subject to Clause 4(b), we exclude and expressly
+disclaim all express and implied representations, warranties, conditions and
+terms not stated herein (including the implied conditions or warranties of
+satisfactory quality, merchantable quality, merchantability and fitness for
+purpose).
+
+(b) Savings. Some jurisdictions may imply warranties, conditions or terms or
+impose obligations upon us which cannot, in whole or in part, be excluded,
+restricted or modified or otherwise do not allow the exclusion of implied
+warranties, conditions or terms, in which case the above warranty disclaimer
+and exclusion will only apply to You to the extent permitted in the relevant
+jurisdiction and does not in any event exclude any implied warranties,
+conditions or terms which may not under applicable law be excluded.
+
+(c) Imperial College London disclaims all responsibility for the use which is
+made of the Software and any liability for the outcomes arising from using the
+Software.
+
+5. Limitation of Liability
+
+(a) You acknowledge that we are providing the Software to You under this
+licence agreement free of charge and on condition that the limitation of
+liability set out below shall apply. Accordingly, subject to Clause 5(b), we
+exclude all liability whether in contract, tort, negligence or otherwise, in
+respect of the Software and/or any related documentation provided to You by us
+including, but not limited to, liability for loss or corruption of data, loss
+of contracts, loss of income, loss of profits, loss of cover and any
+consequential or indirect loss or damage of any kind arising out of or in
+connection with this licence agreement, however caused. This exclusion shall
+apply even if we have been advised of the possibility of such loss or damage.
+
+(b) You agree to indemnify Imperial College London and hold it harmless from
+and against any and all claims, damages and liabilities asserted by third
+parties (including claims for negligence) which arise directly or indirectly
+from the use of the Software or any derivative of it or the sale of any
+products based on the Software. You undertake to make no liability claim
+against any employee, student, agent or appointee of Imperial College London,
+in connection with this Licence or the Software.
+
+(c) Nothing in this Agreement shall have the effect of excluding or limiting
+our statutory liability.
+
+(d) Some jurisdictions do not allow these limitations or exclusions either
+wholly or in part, and, to that extent, they may not apply to you. Nothing in
+this licence agreement will affect your statutory rights or other relevant
+statutory provisions which cannot be excluded, restricted or modified, and its
+terms and conditions must be read and construed subject to any such statutory
+rights and/or provisions.
+
+6. Confidentiality. You agree not to disclose any confidential information
+provided to You by us pursuant to this Agreement to any third party without our
+prior written consent. The obligations in this Clause 6 shall survive the
+termination of this Agreement for any reason.
+
+7. Termination.
+
+(a) We may terminate this licence agreement and your right to use the Software
+at any time with immediate effect upon written notice to You.
+
+(b) This licence agreement and your right to use the Software automatically
+terminate if You:
+
+ (i) fail to comply with any provisions of this Agreement; or
+
+ (ii) destroy the copies of the Software in your possession, or voluntarily
+ return the Software to us.
+
+(c) Upon termination You will destroy all copies of the Software.
+
+(d) Otherwise, the restrictions on your rights to use the Software will expire
+10 (ten) years after first use of the Software under this licence agreement.
+
+8. Miscellaneous Provisions.
+
+(a) This Agreement will be governed by and construed in accordance with the
+substantive laws of England and Wales whose courts shall have exclusive
+jurisdiction over all disputes which may arise between us.
+
+(b) This is the entire agreement between us relating to the Software, and
+supersedes any prior purchase order, communications, advertising or
+representations concerning the Software.
+
+(c) No change or modification of this Agreement will be valid unless it is in
+writing, and is signed by us.
+
+(d) The unenforceability or invalidity of any part of this Agreement will not
+affect the enforceability or validity of the remaining parts.
+
+BSD Elements of the Software
+
+For BSD elements of the Software, the following terms shall apply:
+Copyright as indicated in the header of the individual element of the Software.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+this list of conditions and the following disclaimer in the documentation
+and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its contributors
+may be used to endorse or promote products derived from this software without
+specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+SCHEDULE 1
+
+The Software
+
+DSINE is a framework for estimating surface normals from a single image. It is based on the techniques described in the following publication:
+
+ • Gwangbin Bae, Andrew J. Davison. Rethinking Inductive Biases for Surface Normal Estimation. CVPR, 2024
+_________________________
+
+Acknowledgments
+
+If you use the software, you should reference the following paper in any publication:
+
+ • Gwangbin Bae, Andrew J. Davison. Rethinking Inductive Biases for Surface Normal Estimation. CVPR, 2024
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/dsine/__init__.py b/src/custom_controlnet_aux/dsine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3cf8aa24bffd35d1846d69c55ace52f8cf9ccf3
--- /dev/null
+++ b/src/custom_controlnet_aux/dsine/__init__.py
@@ -0,0 +1,222 @@
+import os
+import types
+import warnings
+from pathlib import Path
+
+import cv2
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from einops import rearrange
+from PIL import Image
+from huggingface_hub import hf_hub_download
+
+from .models.dsine_arch import DSINE
+from .utils.utils import get_intrins_from_fov
+
+# Local constants
+DIFFUSION_EDGE_MODEL_NAME = "hr16/Diffusion-Edge"
+
+# Local utility functions
+def HWC3(x):
+ assert x.dtype == np.uint8
+ if x.ndim == 2:
+ x = x[:, :, None]
+ assert x.ndim == 3
+ H, W, C = x.shape
+ assert C == 1 or C == 3 or C == 4
+ if C == 3:
+ return x
+ if C == 1:
+ return np.concatenate([x, x, x], axis=2)
+ if C == 4:
+ color = x[:, :, 0:3].astype(np.float32)
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+ y = color * alpha + 255.0 * (1.0 - alpha)
+ y = y.clip(0, 255).astype(np.uint8)
+ return y
+
+def pad64(x):
+ return int(np.ceil(float(x) / 64.0) * 64 - x)
+
+def safer_memory(x):
+ return np.ascontiguousarray(x.copy()).copy()
+
+def resize_image_with_pad(input_image, resolution, upscale_method="INTER_CUBIC", skip_hwc3=False, mode='edge'):
+ if skip_hwc3:
+ img = input_image
+ else:
+ img = HWC3(input_image)
+ H_raw, W_raw, _ = img.shape
+ if resolution == 0:
+ return img, lambda x: x
+ k = float(resolution) / float(min(H_raw, W_raw))
+ H_target = int(np.round(float(H_raw) * k))
+ W_target = int(np.round(float(W_raw) * k))
+
+ # Get upscale method
+ upscale_methods = {"INTER_NEAREST": cv2.INTER_NEAREST, "INTER_LINEAR": cv2.INTER_LINEAR,
+ "INTER_AREA": cv2.INTER_AREA, "INTER_CUBIC": cv2.INTER_CUBIC,
+ "INTER_LANCZOS4": cv2.INTER_LANCZOS4}
+ method = upscale_methods.get(upscale_method, cv2.INTER_CUBIC)
+
+ img = cv2.resize(img, (W_target, H_target), interpolation=method if k > 1 else cv2.INTER_AREA)
+ H_pad, W_pad = pad64(H_target), pad64(W_target)
+ img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)
+
+ def remove_pad(x):
+ return safer_memory(x[:H_target, :W_target, ...])
+
+ return safer_memory(img_padded), remove_pad
+
+def common_input_validate(input_image, output_type, **kwargs):
+ if "img" in kwargs:
+ warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning)
+ input_image = kwargs.pop("img")
+
+ if "return_pil" in kwargs:
+ warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning)
+ output_type = "pil" if kwargs["return_pil"] else "np"
+
+ if type(output_type) is bool:
+ warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions")
+ if output_type:
+ output_type = "pil"
+
+ if input_image is None:
+ raise ValueError("input_image must be defined.")
+
+ if not isinstance(input_image, np.ndarray):
+ input_image = np.array(input_image, dtype=np.uint8)
+ output_type = output_type or "pil"
+ else:
+ output_type = output_type or "np"
+
+ return (input_image, output_type)
+
+def custom_hf_download(pretrained_model_or_path, filename, subfolder=''):
+ """Download model files from HuggingFace Hub"""
+ annotator_ckpts_path = os.path.join(Path(__file__).parents[3], 'ckpts')
+ local_dir = os.path.join(annotator_ckpts_path, pretrained_model_or_path)
+ model_path = Path(local_dir).joinpath(*subfolder.split('/'), filename).__str__()
+
+ if not os.path.exists(model_path):
+ print(f"Downloading {filename} from {pretrained_model_or_path}")
+ model_path = hf_hub_download(
+ repo_id=pretrained_model_or_path,
+ filename=filename,
+ subfolder=subfolder,
+ local_dir=local_dir,
+ local_dir_use_symlinks=False
+ )
+
+ print(f"model_path is {model_path}")
+ return model_path
+
+# load model
+def load_checkpoint(fpath, model):
+ ckpt = torch.load(fpath, map_location='cpu')['model']
+
+ load_dict = {}
+ for k, v in ckpt.items():
+ if k.startswith('module.'):
+ k_ = k.replace('module.', '')
+ load_dict[k_] = v
+ else:
+ load_dict[k] = v
+
+ # Load compatible weights only
+ model_state = model.state_dict()
+ compatible_dict = {}
+ skipped_keys = []
+
+ for k, v in load_dict.items():
+ if k in model_state:
+ if model_state[k].shape == v.shape:
+ compatible_dict[k] = v
+ else:
+ skipped_keys.append(f"{k}: checkpoint {v.shape} vs model {model_state[k].shape}")
+ else:
+ skipped_keys.append(f"{k}: not found in model")
+
+ print(f"Loading checkpoint: {len(compatible_dict)} compatible, {len(skipped_keys)} skipped")
+ if skipped_keys:
+ print("Skipped keys with shape mismatches:")
+ for key in skipped_keys[:5]: # Show first 5 mismatches
+ print(f" {key}")
+ if len(skipped_keys) > 5:
+ print(f" ... and {len(skipped_keys) - 5} more")
+
+ model.load_state_dict(compatible_dict, strict=False)
+ return model
+
+def get_pad(orig_H, orig_W):
+ if orig_W % 64 == 0:
+ l = 0
+ r = 0
+ else:
+ new_W = 64 * ((orig_W // 64) + 1)
+ l = (new_W - orig_W) // 2
+ r = (new_W - orig_W) - l
+
+ if orig_H % 64 == 0:
+ t = 0
+ b = 0
+ else:
+ new_H = 64 * ((orig_H // 64) + 1)
+ t = (new_H - orig_H) // 2
+ b = (new_H - orig_H) - t
+ return l, r, t, b
+
+class DsineDetector:
+ def __init__(self, model):
+ self.model = model
+ self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ self.device = "cpu"
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path=DIFFUSION_EDGE_MODEL_NAME, filename="dsine.pt"):
+ model_path = custom_hf_download(pretrained_model_or_path, filename)
+ model = DSINE()
+ model = load_checkpoint(model_path, model)
+ model.eval()
+
+ return cls(model)
+
+ def to(self, device):
+ self.model.to(device)
+ self.model.pixel_coords = self.model.pixel_coords.to(device)
+ self.device = device
+ return self
+
+
+ def __call__(self, input_image, fov=60.0, iterations=5, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", **kwargs):
+ self.model.num_iter = iterations
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ orig_H, orig_W = input_image.shape[:2]
+ l, r, t, b = get_pad(orig_H, orig_W)
+ input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method, mode="constant")
+ with torch.no_grad():
+ input_image = torch.from_numpy(input_image).float().to(self.device)
+ input_image = input_image / 255.0
+ input_image = rearrange(input_image, 'h w c -> 1 c h w')
+ input_image = self.norm(input_image)
+
+ intrins = get_intrins_from_fov(new_fov=fov, H=orig_H, W=orig_W, device=self.device).unsqueeze(0)
+ intrins[:, 0, 2] += l
+ intrins[:, 1, 2] += t
+
+ normal = self.model(input_image, intrins)
+ normal = normal[-1][0]
+ normal = ((normal + 1) * 0.5).clip(0, 1)
+
+ normal = rearrange(normal, 'c h w -> h w c').cpu().numpy()
+ normal_image = (normal * 255.0).clip(0, 255).astype(np.uint8)
+
+ detected_map = HWC3(normal_image)
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
+
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/dsine/models/dsine_arch.py b/src/custom_controlnet_aux/dsine/models/dsine_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f4cc2bab1a5e95906ea16cce953656eac4b51e4
--- /dev/null
+++ b/src/custom_controlnet_aux/dsine/models/dsine_arch.py
@@ -0,0 +1,230 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .submodules import Encoder, ConvGRU, UpSampleBN, UpSampleGN, RayReLU, \
+ convex_upsampling, get_unfold, get_prediction_head, \
+ INPUT_CHANNELS_DICT
+from ..utils.rotation import axis_angle_to_matrix
+
+
+class Decoder(nn.Module):
+ def __init__(self, output_dims, B=5, NF=2048, BN=False, downsample_ratio=8):
+ super(Decoder, self).__init__()
+ input_channels = INPUT_CHANNELS_DICT[B]
+ output_dim, feature_dim, hidden_dim = output_dims
+ features = bottleneck_features = NF
+ self.downsample_ratio = downsample_ratio
+
+ UpSample = UpSampleBN if BN else UpSampleGN
+ self.conv2 = nn.Conv2d(bottleneck_features + 2, features, kernel_size=1, stride=1, padding=0)
+ self.up1 = UpSample(skip_input=features // 1 + input_channels[1] + 2, output_features=features // 2, align_corners=False)
+ self.up2 = UpSample(skip_input=features // 2 + input_channels[2] + 2, output_features=features // 4, align_corners=False)
+
+ # prediction heads
+ i_dim = features // 4
+ h_dim = 128
+ self.normal_head = get_prediction_head(i_dim+2, h_dim, output_dim)
+ self.feature_head = get_prediction_head(i_dim+2, h_dim, feature_dim)
+ self.hidden_head = get_prediction_head(i_dim+2, h_dim, hidden_dim)
+
+ def forward(self, features, uvs):
+ _, _, x_block2, x_block3, x_block4 = None, None, features[5], features[7], features[10] # Skip first two features, use layers 5,7,10
+ uv_32, uv_16, uv_8 = uvs
+
+ x_d0 = self.conv2(torch.cat([x_block4, uv_32], dim=1))
+ x_d1 = self.up1(x_d0, torch.cat([x_block3, uv_16], dim=1))
+ x_feat = self.up2(x_d1, torch.cat([x_block2, uv_8], dim=1))
+ x_feat = torch.cat([x_feat, uv_8], dim=1)
+
+ normal = self.normal_head(x_feat)
+ normal = F.normalize(normal, dim=1)
+ f = self.feature_head(x_feat)
+ h = self.hidden_head(x_feat)
+ return normal, f, h
+
+
+class DSINE(nn.Module):
+ def __init__(self):
+ super(DSINE, self).__init__()
+ self.downsample_ratio = 8
+ self.ps = 5 # patch size
+ self.num_iter = 5 # num iterations
+
+ # define encoder
+ self.encoder = Encoder(B=5, pretrained=True)
+
+ # define decoder
+ self.output_dim = output_dim = 3
+ self.feature_dim = feature_dim = 64
+ self.hidden_dim = hidden_dim = 64
+ self.decoder = Decoder([output_dim, feature_dim, hidden_dim], B=5, NF=2048, BN=False)
+
+ # ray direction-based ReLU
+ self.ray_relu = RayReLU(eps=1e-2)
+
+ # pixel_coords (1, 3, H, W) - adjust dimensions for larger inputs if needed
+ h = 2000
+ w = 2000
+ pixel_coords = np.ones((3, h, w)).astype(np.float32)
+ x_range = np.concatenate([np.arange(w).reshape(1, w)] * h, axis=0)
+ y_range = np.concatenate([np.arange(h).reshape(h, 1)] * w, axis=1)
+ pixel_coords[0, :, :] = x_range + 0.5
+ pixel_coords[1, :, :] = y_range + 0.5
+ self.pixel_coords = torch.from_numpy(pixel_coords).unsqueeze(0)
+
+ # define ConvGRU cell
+ self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=feature_dim+2, ks=self.ps)
+
+ # padding used during NRN
+ self.pad = (self.ps - 1) // 2
+
+ # prediction heads
+ self.prob_head = get_prediction_head(self.hidden_dim+2, 64, self.ps*self.ps) # weights assigned for each nghbr pixel
+ self.xy_head = get_prediction_head(self.hidden_dim+2, 64, self.ps*self.ps*2) # rotation axis for each nghbr pixel
+ self.angle_head = get_prediction_head(self.hidden_dim+2, 64, self.ps*self.ps) # rotation angle for each nghbr pixel
+
+ # prediction heads - weights used for upsampling the coarse resolution output
+ self.up_prob_head = get_prediction_head(self.hidden_dim+2, 64, 9 * self.downsample_ratio * self.downsample_ratio)
+
+ def get_ray(self, intrins, H, W, orig_H, orig_W, return_uv=False):
+ B, _, _ = intrins.shape
+ fu = intrins[:, 0, 0][:,None,None] * (W / orig_W)
+ cu = intrins[:, 0, 2][:,None,None] * (W / orig_W)
+ fv = intrins[:, 1, 1][:,None,None] * (H / orig_H)
+ cv = intrins[:, 1, 2][:,None,None] * (H / orig_H)
+
+ # (B, 2, H, W)
+ ray = self.pixel_coords[:, :, :H, :W].repeat(B, 1, 1, 1)
+ ray[:, 0, :, :] = (ray[:, 0, :, :] - cu) / fu
+ ray[:, 1, :, :] = (ray[:, 1, :, :] - cv) / fv
+
+ if return_uv:
+ return ray[:, :2, :, :]
+ else:
+ return F.normalize(ray, dim=1)
+
+ def upsample(self, h, pred_norm, uv_8):
+ up_mask = self.up_prob_head(torch.cat([h, uv_8], dim=1))
+ up_pred_norm = convex_upsampling(pred_norm, up_mask, self.downsample_ratio)
+ up_pred_norm = F.normalize(up_pred_norm, dim=1)
+ return up_pred_norm
+
+ def refine(self, h, feat_map, pred_norm, intrins, orig_H, orig_W, uv_8, ray_8):
+ B, C, H, W = pred_norm.shape
+ fu = intrins[:, 0, 0][:,None,None,None] * (W / orig_W) # (B, 1, 1, 1)
+ cu = intrins[:, 0, 2][:,None,None,None] * (W / orig_W)
+ fv = intrins[:, 1, 1][:,None,None,None] * (H / orig_H)
+ cv = intrins[:, 1, 2][:,None,None,None] * (H / orig_H)
+
+ h_new = self.gru(h, feat_map)
+
+ # get nghbr prob (B, 1, ps*ps, h, w)
+ nghbr_prob = self.prob_head(torch.cat([h_new, uv_8], dim=1)).unsqueeze(1)
+ nghbr_prob = torch.sigmoid(nghbr_prob)
+
+ # get nghbr normals (B, 3, ps*ps, h, w)
+ nghbr_normals = get_unfold(pred_norm, ps=self.ps, pad=self.pad)
+
+ # get nghbr xy (B, 2, ps*ps, h, w)
+ nghbr_xys = self.xy_head(torch.cat([h_new, uv_8], dim=1))
+ nghbr_xs, nghbr_ys = torch.split(nghbr_xys, [self.ps*self.ps, self.ps*self.ps], dim=1)
+ nghbr_xys = torch.cat([nghbr_xs.unsqueeze(1), nghbr_ys.unsqueeze(1)], dim=1)
+ nghbr_xys = F.normalize(nghbr_xys, dim=1)
+
+ # get nghbr theta (B, 1, ps*ps, h, w)
+ nghbr_angle = self.angle_head(torch.cat([h_new, uv_8], dim=1)).unsqueeze(1)
+ nghbr_angle = torch.sigmoid(nghbr_angle) * np.pi
+
+ # get nghbr pixel coord (1, 3, ps*ps, h, w)
+ nghbr_pixel_coord = get_unfold(self.pixel_coords[:, :, :H, :W], ps=self.ps, pad=self.pad)
+
+ # nghbr axes (B, 3, ps*ps, h, w)
+ nghbr_axes = torch.zeros_like(nghbr_normals)
+
+ du_over_fu = nghbr_xys[:, 0, ...] / fu # (B, ps*ps, h, w)
+ dv_over_fv = nghbr_xys[:, 1, ...] / fv # (B, ps*ps, h, w)
+
+ term_u = (nghbr_pixel_coord[:, 0, ...] + nghbr_xys[:, 0, ...] - cu) / fu # (B, ps*ps, h, w)
+ term_v = (nghbr_pixel_coord[:, 1, ...] + nghbr_xys[:, 1, ...] - cv) / fv # (B, ps*ps, h, w)
+
+ nx = nghbr_normals[:, 0, ...] # (B, ps*ps, h, w)
+ ny = nghbr_normals[:, 1, ...] # (B, ps*ps, h, w)
+ nz = nghbr_normals[:, 2, ...] # (B, ps*ps, h, w)
+
+ nghbr_delta_z_num = - (du_over_fu * nx + dv_over_fv * ny)
+ nghbr_delta_z_denom = (term_u * nx + term_v * ny + nz)
+ nghbr_delta_z_denom[torch.abs(nghbr_delta_z_denom) < 1e-8] = 1e-8 * torch.sign(nghbr_delta_z_denom[torch.abs(nghbr_delta_z_denom) < 1e-8])
+ nghbr_delta_z = nghbr_delta_z_num / nghbr_delta_z_denom
+
+ nghbr_axes[:, 0, ...] = du_over_fu + nghbr_delta_z * term_u
+ nghbr_axes[:, 1, ...] = dv_over_fv + nghbr_delta_z * term_v
+ nghbr_axes[:, 2, ...] = nghbr_delta_z
+ nghbr_axes = F.normalize(nghbr_axes, dim=1) # (B, 3, ps*ps, h, w)
+
+ # make sure axes are all valid
+ invalid = torch.sum(torch.logical_or(torch.isnan(nghbr_axes), torch.isinf(nghbr_axes)).float(), dim=1) > 0.5 # (B, ps*ps, h, w)
+ nghbr_axes[:, 0, ...][invalid] = 0.0
+ nghbr_axes[:, 1, ...][invalid] = 0.0
+ nghbr_axes[:, 2, ...][invalid] = 0.0
+
+ # nghbr_axes_angle (B, 3, ps*ps, h, w)
+ nghbr_axes_angle = nghbr_axes * nghbr_angle
+ nghbr_axes_angle = nghbr_axes_angle.permute(0, 2, 3, 4, 1) # (B, ps*ps, h, w, 3)
+ nghbr_R = axis_angle_to_matrix(nghbr_axes_angle) # (B, ps*ps, h, w, 3, 3)
+
+ # (B, 3, ps*ps, h, w)
+ nghbr_normals_rot = torch.bmm(
+ nghbr_R.reshape(B * self.ps * self.ps * H * W, 3, 3),
+ nghbr_normals.permute(0, 2, 3, 4, 1).reshape(B * self.ps * self.ps * H * W, 3).unsqueeze(-1)
+ ).reshape(B, self.ps*self.ps, H, W, 3, 1).squeeze(-1).permute(0, 4, 1, 2, 3) # (B, 3, ps*ps, h, w)
+ nghbr_normals_rot = F.normalize(nghbr_normals_rot, dim=1)
+
+ # ray ReLU
+ nghbr_normals_rot = torch.cat([
+ self.ray_relu(nghbr_normals_rot[:, :, i, :, :], ray_8).unsqueeze(2)
+ for i in range(nghbr_normals_rot.size(2))
+ ], dim=2)
+
+ # (B, 1, ps*ps, h, w) * (B, 3, ps*ps, h, w)
+ pred_norm = torch.sum(nghbr_prob * nghbr_normals_rot, dim=2) # (B, C, H, W)
+ pred_norm = F.normalize(pred_norm, dim=1)
+
+ up_mask = self.up_prob_head(torch.cat([h_new, uv_8], dim=1))
+ up_pred_norm = convex_upsampling(pred_norm, up_mask, self.downsample_ratio)
+ up_pred_norm = F.normalize(up_pred_norm, dim=1)
+
+ return h_new, pred_norm, up_pred_norm
+
+
+ def forward(self, img, intrins=None):
+ # Step 1. encoder
+ features = self.encoder(img)
+
+ # Step 2. get uv encoding
+ B, _, orig_H, orig_W = img.shape
+ intrins[:, 0, 2] += 0.5
+ intrins[:, 1, 2] += 0.5
+ uv_32 = self.get_ray(intrins, orig_H//32, orig_W//32, orig_H, orig_W, return_uv=True)
+ uv_16 = self.get_ray(intrins, orig_H//16, orig_W//16, orig_H, orig_W, return_uv=True)
+ uv_8 = self.get_ray(intrins, orig_H//8, orig_W//8, orig_H, orig_W, return_uv=True)
+ ray_8 = self.get_ray(intrins, orig_H//8, orig_W//8, orig_H, orig_W)
+
+ # Step 3. decoder - initial prediction
+ pred_norm, feat_map, h = self.decoder(features, uvs=(uv_32, uv_16, uv_8))
+ pred_norm = self.ray_relu(pred_norm, ray_8)
+
+ # Step 4. add ray direction encoding
+ feat_map = torch.cat([feat_map, uv_8], dim=1)
+
+ # iterative refinement
+ up_pred_norm = self.upsample(h, pred_norm, uv_8)
+ pred_list = [up_pred_norm]
+ for i in range(self.num_iter):
+ h, pred_norm, up_pred_norm = self.refine(h, feat_map,
+ pred_norm.detach(),
+ intrins, orig_H, orig_W, uv_8, ray_8)
+ pred_list.append(up_pred_norm)
+ return pred_list
+
diff --git a/src/custom_controlnet_aux/dsine/models/submodules/__init__.py b/src/custom_controlnet_aux/dsine/models/submodules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee9fb0a4590489d9aa8d6bd6c3ff57f23f7f86c6
--- /dev/null
+++ b/src/custom_controlnet_aux/dsine/models/submodules/__init__.py
@@ -0,0 +1,176 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import os
+import timm
+
+
+INPUT_CHANNELS_DICT = {
+ 0: [1280, 112, 40, 24, 16],
+ 1: [1280, 112, 40, 24, 16],
+ 2: [1408, 120, 48, 24, 16],
+ 3: [1536, 136, 48, 32, 24],
+ 4: [1792, 160, 56, 32, 24],
+ 5: [2048, 176, 64, None, None], # EfficientNet-B5: features[10,7,5]
+ 6: [2304, 200, 72, 40, 32],
+ 7: [2560, 224, 80, 48, 32]
+}
+
+
+from .standalone_encoder import StandaloneEncoder
+Encoder = StandaloneEncoder
+
+
+class ConvGRU(nn.Module):
+ def __init__(self, hidden_dim, input_dim, ks=3):
+ super(ConvGRU, self).__init__()
+ p = (ks - 1) // 2
+ self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, ks, padding=p)
+ self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, ks, padding=p)
+ self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, ks, padding=p)
+
+ def forward(self, h, x):
+ hx = torch.cat([h, x], dim=1)
+ z = torch.sigmoid(self.convz(hx))
+ r = torch.sigmoid(self.convr(hx))
+ q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
+ h = (1-z) * h + z * q
+ return h
+
+
+class RayReLU(nn.Module):
+ def __init__(self, eps=1e-2):
+ super(RayReLU, self).__init__()
+ self.eps = eps
+
+ def forward(self, pred_norm, ray):
+ # angle between the predicted normal and ray direction
+ cos = torch.cosine_similarity(pred_norm, ray, dim=1).unsqueeze(1) # (B, 1, H, W)
+
+ # component of pred_norm along view
+ norm_along_view = ray * cos
+
+ # cos should be bigger than eps
+ norm_along_view_relu = ray * (torch.relu(cos - self.eps) + self.eps)
+
+ # difference
+ diff = norm_along_view_relu - norm_along_view
+
+ # updated pred_norm
+ new_pred_norm = pred_norm + diff
+ new_pred_norm = F.normalize(new_pred_norm, dim=1)
+
+ return new_pred_norm
+
+
+class UpSampleBN(nn.Module):
+ def __init__(self, skip_input, output_features, align_corners=True):
+ super(UpSampleBN, self).__init__()
+ self._net = nn.Sequential(nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2d(output_features),
+ nn.LeakyReLU(),
+ nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2d(output_features),
+ nn.LeakyReLU())
+ self.align_corners = align_corners
+
+ def forward(self, x, concat_with):
+ up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=self.align_corners)
+ f = torch.cat([up_x, concat_with], dim=1)
+ return self._net(f)
+
+
+class Conv2d_WS(nn.Conv2d):
+ """ weight standardization
+ """
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+ padding=0, dilation=1, groups=1, bias=True):
+ super(Conv2d_WS, self).__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias)
+
+ def forward(self, x):
+ weight = self.weight
+ weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
+ keepdim=True).mean(dim=3, keepdim=True)
+ weight = weight - weight_mean
+ std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
+ weight = weight / std.expand_as(weight)
+ return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+
+class UpSampleGN(nn.Module):
+ """ UpSample with GroupNorm
+ """
+ def __init__(self, skip_input, output_features, align_corners=True):
+ super(UpSampleGN, self).__init__()
+ self._net = nn.Sequential(Conv2d_WS(skip_input, output_features, kernel_size=3, stride=1, padding=1),
+ nn.GroupNorm(8, output_features),
+ nn.LeakyReLU(),
+ Conv2d_WS(output_features, output_features, kernel_size=3, stride=1, padding=1),
+ nn.GroupNorm(8, output_features),
+ nn.LeakyReLU())
+ self.align_corners = align_corners
+
+ def forward(self, x, concat_with):
+ up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=self.align_corners)
+ f = torch.cat([up_x, concat_with], dim=1)
+ return self._net(f)
+
+
+def upsample_via_bilinear(out, up_mask, downsample_ratio):
+ """ bilinear upsampling (up_mask is a dummy variable)
+ """
+ return F.interpolate(out, scale_factor=downsample_ratio, mode='bilinear', align_corners=True)
+
+
+def upsample_via_mask(out, up_mask, downsample_ratio):
+ """ convex upsampling
+ """
+ # out: low-resolution output (B, o_dim, H, W)
+ # up_mask: (B, 9*k*k, H, W)
+ k = downsample_ratio
+
+ N, o_dim, H, W = out.shape
+ up_mask = up_mask.view(N, 1, 9, k, k, H, W)
+ up_mask = torch.softmax(up_mask, dim=2) # (B, 1, 9, k, k, H, W)
+
+ up_out = F.unfold(out, [3, 3], padding=1) # (B, 2, H, W) -> (B, 2 X 3*3, H*W)
+ up_out = up_out.view(N, o_dim, 9, 1, 1, H, W) # (B, 2, 3*3, 1, 1, H, W)
+ up_out = torch.sum(up_mask * up_out, dim=2) # (B, 2, k, k, H, W)
+
+ up_out = up_out.permute(0, 1, 4, 2, 5, 3) # (B, 2, H, k, W, k)
+ return up_out.reshape(N, o_dim, k*H, k*W) # (B, 2, kH, kW)
+
+
+def convex_upsampling(out, up_mask, k):
+ # out: low-resolution output (B, C, H, W)
+ # up_mask: (B, 9*k*k, H, W)
+ B, C, H, W = out.shape
+ up_mask = up_mask.view(B, 1, 9, k, k, H, W)
+ up_mask = torch.softmax(up_mask, dim=2) # (B, 1, 9, k, k, H, W)
+
+ out = F.pad(out, pad=(1,1,1,1), mode='replicate')
+ up_out = F.unfold(out, [3, 3], padding=0) # (B, C, H, W) -> (B, C X 3*3, H*W)
+ up_out = up_out.view(B, C, 9, 1, 1, H, W) # (B, C, 9, 1, 1, H, W)
+
+ up_out = torch.sum(up_mask * up_out, dim=2) # (B, C, k, k, H, W)
+ up_out = up_out.permute(0, 1, 4, 2, 5, 3) # (B, C, H, k, W, k)
+ return up_out.reshape(B, C, k*H, k*W) # (B, C, kH, kW)
+
+
+def get_unfold(pred_norm, ps, pad):
+ B, C, H, W = pred_norm.shape
+ pred_norm = F.pad(pred_norm, pad=(pad,pad,pad,pad), mode='replicate') # (B, C, h, w)
+ pred_norm_unfold = F.unfold(pred_norm, [ps, ps], padding=0) # (B, C X ps*ps, h*w)
+ pred_norm_unfold = pred_norm_unfold.view(B, C, ps*ps, H, W) # (B, C, ps*ps, h, w)
+ return pred_norm_unfold
+
+
+def get_prediction_head(input_dim, hidden_dim, output_dim):
+ return nn.Sequential(
+ nn.Conv2d(input_dim, hidden_dim, 3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(hidden_dim, hidden_dim, 1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(hidden_dim, output_dim, 1),
+ )
diff --git a/src/custom_controlnet_aux/dsine/models/submodules/standalone_encoder.py b/src/custom_controlnet_aux/dsine/models/submodules/standalone_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e823971e67fee63d7473f908e1e3e4d48dc2e5eb
--- /dev/null
+++ b/src/custom_controlnet_aux/dsine/models/submodules/standalone_encoder.py
@@ -0,0 +1,39 @@
+"""
+Standalone DSINE Encoder using EfficientNet-B5 backbone.
+"""
+import torch
+import torch.nn as nn
+import timm
+
+INPUT_CHANNELS_DICT = {
+ 0: [1280, 112, 40, 24, 16],
+ 1: [1280, 112, 40, 24, 16],
+ 2: [1408, 120, 48, 24, 16],
+ 3: [1536, 136, 48, 32, 24],
+ 4: [1792, 160, 56, 32, 24],
+ 5: [2048, 176, 64, None, None], # EfficientNet-B5: features[10,7,5]
+ 6: [2304, 200, 72, 40, 32],
+ 7: [2560, 224, 80, 48, 32]
+}
+
+class StandaloneEncoder(nn.Module):
+ """EfficientNet encoder for DSINE depth estimation."""
+ def __init__(self, B=5, pretrained=True):
+ super(StandaloneEncoder, self).__init__()
+
+ basemodel_name = f'tf_efficientnet_b{B}.ap_in1k'
+ basemodel = timm.create_model(basemodel_name, pretrained=False, num_classes=0)
+ basemodel.global_pool = nn.Identity()
+ basemodel.classifier = nn.Identity()
+
+ self.original_model = basemodel
+
+ def forward(self, x):
+ features = [x]
+ for k, v in self.original_model._modules.items():
+ if (k == 'blocks'):
+ for ki, vi in v._modules.items():
+ features.append(vi(features[-1]))
+ else:
+ features.append(v(features[-1]))
+ return features
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/dsine/utils/rotation.py b/src/custom_controlnet_aux/dsine/utils/rotation.py
new file mode 100644
index 0000000000000000000000000000000000000000..8123f900cee404d10dd96e03ce31a681c594e284
--- /dev/null
+++ b/src/custom_controlnet_aux/dsine/utils/rotation.py
@@ -0,0 +1,85 @@
+import torch
+import numpy as np
+
+
+# NOTE: from PyTorch3D
+def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as axis/angle to quaternions.
+
+ Args:
+ axis_angle: Rotations given as a vector in axis angle form,
+ as a tensor of shape (..., 3), where the magnitude is
+ the angle turned anticlockwise in radians around the
+ vector's direction.
+
+ Returns:
+ quaternions with real part first, as tensor of shape (..., 4).
+ """
+ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
+ half_angles = angles * 0.5
+ eps = 1e-6
+ small_angles = angles.abs() < eps
+ sin_half_angles_over_angles = torch.empty_like(angles)
+ sin_half_angles_over_angles[~small_angles] = (
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
+ )
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
+ sin_half_angles_over_angles[small_angles] = (
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
+ )
+ quaternions = torch.cat(
+ [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
+ )
+ return quaternions
+
+
+# NOTE: from PyTorch3D
+def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as quaternions to rotation matrices.
+
+ Args:
+ quaternions: quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ r, i, j, k = torch.unbind(quaternions, -1)
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ ),
+ -1,
+ )
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+# NOTE: from PyTorch3D
+def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as axis/angle to rotation matrices.
+
+ Args:
+ axis_angle: Rotations given as a vector in axis angle form,
+ as a tensor of shape (..., 3), where the magnitude is
+ the angle turned anticlockwise in radians around the
+ vector's direction.
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/dsine/utils/utils.py b/src/custom_controlnet_aux/dsine/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..61cd4cca0dd2f1ba05a77d139c90f226fec5af04
--- /dev/null
+++ b/src/custom_controlnet_aux/dsine/utils/utils.py
@@ -0,0 +1,105 @@
+""" utils
+"""
+import os
+import torch
+import numpy as np
+
+
+def load_checkpoint(fpath, model):
+ print('loading checkpoint... {}'.format(fpath))
+
+ ckpt = torch.load(fpath, map_location='cpu')['model']
+
+ load_dict = {}
+ for k, v in ckpt.items():
+ if k.startswith('module.'):
+ k_ = k.replace('module.', '')
+ load_dict[k_] = v
+ else:
+ load_dict[k] = v
+
+ model.load_state_dict(load_dict)
+ print('loading checkpoint... / done')
+ return model
+
+
+def compute_normal_error(pred_norm, gt_norm):
+ pred_error = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
+ pred_error = torch.clamp(pred_error, min=-1.0, max=1.0)
+ pred_error = torch.acos(pred_error) * 180.0 / np.pi
+ pred_error = pred_error.unsqueeze(1) # (B, 1, H, W)
+ return pred_error
+
+
+def compute_normal_metrics(total_normal_errors):
+ total_normal_errors = total_normal_errors.detach().cpu().numpy()
+ num_pixels = total_normal_errors.shape[0]
+
+ metrics = {
+ 'mean': np.average(total_normal_errors),
+ 'median': np.median(total_normal_errors),
+ 'rmse': np.sqrt(np.sum(total_normal_errors * total_normal_errors) / num_pixels),
+ 'a1': 100.0 * (np.sum(total_normal_errors < 5) / num_pixels),
+ 'a2': 100.0 * (np.sum(total_normal_errors < 7.5) / num_pixels),
+ 'a3': 100.0 * (np.sum(total_normal_errors < 11.25) / num_pixels),
+ 'a4': 100.0 * (np.sum(total_normal_errors < 22.5) / num_pixels),
+ 'a5': 100.0 * (np.sum(total_normal_errors < 30) / num_pixels)
+ }
+
+ return metrics
+
+
+def pad_input(orig_H, orig_W):
+ if orig_W % 32 == 0:
+ l = 0
+ r = 0
+ else:
+ new_W = 32 * ((orig_W // 32) + 1)
+ l = (new_W - orig_W) // 2
+ r = (new_W - orig_W) - l
+
+ if orig_H % 32 == 0:
+ t = 0
+ b = 0
+ else:
+ new_H = 32 * ((orig_H // 32) + 1)
+ t = (new_H - orig_H) // 2
+ b = (new_H - orig_H) - t
+ return l, r, t, b
+
+
+def get_intrins_from_fov(new_fov, H, W, device):
+ # NOTE: top-left pixel should be (0,0)
+ if W >= H:
+ new_fu = (W / 2.0) / np.tan(np.deg2rad(new_fov / 2.0))
+ new_fv = (W / 2.0) / np.tan(np.deg2rad(new_fov / 2.0))
+ else:
+ new_fu = (H / 2.0) / np.tan(np.deg2rad(new_fov / 2.0))
+ new_fv = (H / 2.0) / np.tan(np.deg2rad(new_fov / 2.0))
+
+ new_cu = (W / 2.0) - 0.5
+ new_cv = (H / 2.0) - 0.5
+
+ new_intrins = torch.tensor([
+ [new_fu, 0, new_cu ],
+ [0, new_fv, new_cv ],
+ [0, 0, 1 ]
+ ], dtype=torch.float32, device=device)
+
+ return new_intrins
+
+
+def get_intrins_from_txt(intrins_path, device):
+ # NOTE: top-left pixel should be (0,0)
+ with open(intrins_path, 'r') as f:
+ intrins_ = f.readlines()[0].split()[0].split(',')
+ intrins_ = [float(i) for i in intrins_]
+ fx, fy, cx, cy = intrins_
+
+ intrins = torch.tensor([
+ [fx, 0,cx],
+ [ 0,fy,cy],
+ [ 0, 0, 1]
+ ], dtype=torch.float32, device=device)
+
+ return intrins
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/dwpose/LICENSE b/src/custom_controlnet_aux/dwpose/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..6f60b76d35fa1012809985780964a5068adce4fd
--- /dev/null
+++ b/src/custom_controlnet_aux/dwpose/LICENSE
@@ -0,0 +1,108 @@
+OPENPOSE: MULTIPERSON KEYPOINT DETECTION
+SOFTWARE LICENSE AGREEMENT
+ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH USE ONLY
+
+BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE.
+
+This is a license agreement ("Agreement") between your academic institution or non-profit organization or self (called "Licensee" or "You" in this Agreement) and Carnegie Mellon University (called "Licensor" in this Agreement). All rights not specifically granted to you in this Agreement are reserved for Licensor.
+
+RESERVATION OF OWNERSHIP AND GRANT OF LICENSE:
+Licensor retains exclusive ownership of any copy of the Software (as defined below) licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive,
+non-transferable license to use the Software for noncommercial research purposes, without the right to sublicense, pursuant to the terms and conditions of this Agreement. As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i).
+
+CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication.
+
+COPYRIGHT: The Software is owned by Licensor and is protected by United
+States copyright laws and applicable international treaties and/or conventions.
+
+PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto.
+
+DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement.
+
+BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies.
+
+USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark “OpenPose", "Carnegie Mellon" or any renditions thereof without the prior written permission of Licensor.
+
+You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software.
+
+ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void.
+
+TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by downloading the Software or by using the Software until terminated as provided below.
+
+The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement.
+
+FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement.
+
+DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS.
+
+SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement.
+
+EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage.
+
+EXPORT REGULATION: Licensee agrees to comply with any and all applicable
+U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control.
+
+SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby.
+
+NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor.
+
+GOVERNING LAW: This Agreement shall be construed and enforced in accordance with the laws of the Commonwealth of Pennsylvania without reference to conflict of laws principles. You consent to the personal jurisdiction of the courts of this County and waive their rights to venue outside of Allegheny County, Pennsylvania.
+
+ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto.
+
+
+
+************************************************************************
+
+THIRD-PARTY SOFTWARE NOTICES AND INFORMATION
+
+This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. We reserves all other rights not expressly granted, whether by implication, estoppel or otherwise.
+
+1. Caffe, version 1.0.0, (https://github.com/BVLC/caffe/)
+
+COPYRIGHT
+
+All contributions by the University of California:
+Copyright (c) 2014-2017 The Regents of the University of California (Regents)
+All rights reserved.
+
+All other contributions:
+Copyright (c) 2014-2017, the respective contributors
+All rights reserved.
+
+Caffe uses a shared copyright model: each contributor holds copyright over
+their contributions to Caffe. The project versioning records all such
+contribution and copyright details. If a contributor wants to further mark
+their specific copyright on a particular contribution, they should indicate
+their copyright solely in the commit message of the change when it is
+committed.
+
+LICENSE
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+CONTRIBUTION AGREEMENT
+
+By contributing to the BVLC/caffe repository through pull-request, comment,
+or otherwise, the contributor releases their content to the
+license and copyright terms herein.
+
+************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION**********
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/dwpose/__init__.py b/src/custom_controlnet_aux/dwpose/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9152befb0a544a91de5a9b131a6fece63020a62
--- /dev/null
+++ b/src/custom_controlnet_aux/dwpose/__init__.py
@@ -0,0 +1,332 @@
+# Openpose
+# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
+# 2nd Edited by https://github.com/Hzzone/pytorch-openpose
+# 3rd Edited by ControlNet
+# 4th Edited by ControlNet (added face and correct hands)
+# 5th Edited by ControlNet (Improved JSON serialization/deserialization, and lots of bug fixs)
+# This preprocessor is licensed by CMU for non-commercial use only.
+
+import os
+os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
+
+import json
+import torch
+import numpy as np
+from . import util
+from .body import Body, BodyResult, Keypoint
+from .hand import Hand
+from .face import Face
+from .types import PoseResult, HandResult, FaceResult, AnimalPoseResult
+from huggingface_hub import hf_hub_download
+from .wholebody import Wholebody
+import warnings
+from custom_controlnet_aux.util import HWC3, resize_image_with_pad, common_input_validate, custom_hf_download
+import cv2
+from PIL import Image
+from .animalpose import AnimalPoseImage
+
+from typing import Tuple, List, Callable, Union, Optional
+
+
+def draw_animalposes(animals: list[list[Keypoint]], H: int, W: int) -> np.ndarray:
+ canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
+ for animal_pose in animals:
+ canvas = draw_animalpose(canvas, animal_pose)
+ return canvas
+
+
+def draw_animalpose(canvas: np.ndarray, keypoints: list[Keypoint]) -> np.ndarray:
+ # order of the keypoints for AP10k and a standardized list of colors for limbs
+ keypointPairsList = [
+ (1, 2),
+ (2, 3),
+ (1, 3),
+ (3, 4),
+ (4, 9),
+ (9, 10),
+ (10, 11),
+ (4, 6),
+ (6, 7),
+ (7, 8),
+ (4, 5),
+ (5, 15),
+ (15, 16),
+ (16, 17),
+ (5, 12),
+ (12, 13),
+ (13, 14),
+ ]
+ colorsList = [
+ (255, 255, 255),
+ (100, 255, 100),
+ (150, 255, 255),
+ (100, 50, 255),
+ (50, 150, 200),
+ (0, 255, 255),
+ (0, 150, 0),
+ (0, 0, 255),
+ (0, 0, 150),
+ (255, 50, 255),
+ (255, 0, 255),
+ (255, 0, 0),
+ (150, 0, 0),
+ (255, 255, 100),
+ (0, 150, 0),
+ (255, 255, 0),
+ (150, 150, 150),
+ ] # 16 colors needed
+
+ for ind, (i, j) in enumerate(keypointPairsList):
+ p1 = keypoints[i - 1]
+ p2 = keypoints[j - 1]
+
+ if p1 is not None and p2 is not None:
+ cv2.line(
+ canvas,
+ (int(p1.x), int(p1.y)),
+ (int(p2.x), int(p2.y)),
+ colorsList[ind],
+ 5,
+ )
+ return canvas
+
+
+def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, draw_face=True, xinsr_stick_scaling=False):
+ """
+ Draw the detected poses on an empty canvas.
+
+ Args:
+ poses (List[PoseResult]): A list of PoseResult objects containing the detected poses.
+ H (int): The height of the canvas.
+ W (int): The width of the canvas.
+ draw_body (bool, optional): Whether to draw body keypoints. Defaults to True.
+ draw_hand (bool, optional): Whether to draw hand keypoints. Defaults to True.
+ draw_face (bool, optional): Whether to draw face keypoints. Defaults to True.
+
+ Returns:
+ numpy.ndarray: A 3D numpy array representing the canvas with the drawn poses.
+ """
+ canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
+
+ for pose in poses:
+ if draw_body:
+ canvas = util.draw_bodypose(canvas, pose.body.keypoints, xinsr_stick_scaling)
+
+ if draw_hand:
+ canvas = util.draw_handpose(canvas, pose.left_hand)
+ canvas = util.draw_handpose(canvas, pose.right_hand)
+
+ if draw_face:
+ canvas = util.draw_facepose(canvas, pose.face)
+
+ return canvas
+
+
+def decode_json_as_poses(
+ pose_json: dict,
+) -> Tuple[List[PoseResult], List[AnimalPoseResult], int, int]:
+ """Decode the json_string complying with the openpose JSON output format
+ to poses that controlnet recognizes.
+ https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/doc/02_output.md
+
+ Args:
+ json_string: The json string to decode.
+
+ Returns:
+ human_poses
+ animal_poses
+ canvas_height
+ canvas_width
+ """
+ height = pose_json["canvas_height"]
+ width = pose_json["canvas_width"]
+
+ def chunks(lst, n):
+ """Yield successive n-sized chunks from lst."""
+ for i in range(0, len(lst), n):
+ yield lst[i : i + n]
+
+ def decompress_keypoints(
+ numbers: Optional[List[float]],
+ ) -> Optional[List[Optional[Keypoint]]]:
+ if not numbers:
+ return None
+
+ assert len(numbers) % 3 == 0
+
+ def create_keypoint(x, y, c):
+ if c < 1.0:
+ return None
+ keypoint = Keypoint(x, y)
+ return keypoint
+
+ return [create_keypoint(x, y, c) for x, y, c in chunks(numbers, n=3)]
+
+ return (
+ [
+ PoseResult(
+ body=BodyResult(
+ keypoints=decompress_keypoints(pose.get("pose_keypoints_2d"))
+ ),
+ left_hand=decompress_keypoints(pose.get("hand_left_keypoints_2d")),
+ right_hand=decompress_keypoints(pose.get("hand_right_keypoints_2d")),
+ face=decompress_keypoints(pose.get("face_keypoints_2d")),
+ )
+ for pose in pose_json.get("people", [])
+ ],
+ [decompress_keypoints(pose) for pose in pose_json.get("animals", [])],
+ height,
+ width,
+ )
+
+
+def encode_poses_as_dict(poses: List[PoseResult], canvas_height: int, canvas_width: int) -> str:
+ """ Encode the pose as a dict following openpose JSON output format:
+ https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/doc/02_output.md
+ """
+ def compress_keypoints(keypoints: Union[List[Keypoint], None]) -> Union[List[float], None]:
+ if not keypoints:
+ return None
+
+ return [
+ value
+ for keypoint in keypoints
+ for value in (
+ [float(keypoint.x), float(keypoint.y), 1.0]
+ if keypoint is not None
+ else [0.0, 0.0, 0.0]
+ )
+ ]
+
+ return {
+ 'people': [
+ {
+ 'pose_keypoints_2d': compress_keypoints(pose.body.keypoints),
+ "face_keypoints_2d": compress_keypoints(pose.face),
+ "hand_left_keypoints_2d": compress_keypoints(pose.left_hand),
+ "hand_right_keypoints_2d":compress_keypoints(pose.right_hand),
+ }
+ for pose in poses
+ ],
+ 'canvas_height': canvas_height,
+ 'canvas_width': canvas_width,
+ }
+
+global_cached_dwpose = Wholebody()
+
+class DwposeDetector:
+ """
+ A class for detecting human poses in images using the Dwpose model.
+
+ Attributes:
+ model_dir (str): Path to the directory where the pose models are stored.
+ """
+ def __init__(self, dw_pose_estimation):
+ self.dw_pose_estimation = dw_pose_estimation
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path, pretrained_det_model_or_path=None, det_filename=None, pose_filename=None, torchscript_device="cuda"):
+ global global_cached_dwpose
+ pretrained_det_model_or_path = pretrained_det_model_or_path or pretrained_model_or_path
+
+ pose_filename = pose_filename or "dw-ll_ucoco_384.onnx"
+
+ det_model_path = None
+ if det_filename is not None:
+ det_model_path = custom_hf_download(pretrained_det_model_or_path, det_filename)
+ pose_model_path = custom_hf_download(pretrained_model_or_path, pose_filename)
+
+ print(f"\nDWPose: Using {det_filename} for bbox detection and {pose_filename} for pose estimation")
+ if global_cached_dwpose.det is None or global_cached_dwpose.det_filename != det_filename:
+ t = Wholebody(det_model_path, None, torchscript_device=torchscript_device)
+ t.pose = global_cached_dwpose.pose
+ t.pose_filename = global_cached_dwpose.pose
+ global_cached_dwpose = t
+
+ if global_cached_dwpose.pose is None or global_cached_dwpose.pose_filename != pose_filename:
+ t = Wholebody(None, pose_model_path, torchscript_device=torchscript_device)
+ t.det = global_cached_dwpose.det
+ t.det_filename = global_cached_dwpose.det_filename
+ global_cached_dwpose = t
+ return cls(global_cached_dwpose)
+
+ def detect_poses(self, oriImg) -> List[PoseResult]:
+ with torch.no_grad():
+ keypoints_info = self.dw_pose_estimation(oriImg.copy())
+ return Wholebody.format_result(keypoints_info)
+
+ def __call__(self, input_image, detect_resolution=512, include_body=True, include_hand=False, include_face=False, hand_and_face=None, output_type="pil", image_and_json=False, upscale_method="INTER_CUBIC", xinsr_stick_scaling=False, **kwargs):
+ if hand_and_face is not None:
+ warnings.warn("hand_and_face is deprecated. Use include_hand and include_face instead.", DeprecationWarning)
+ include_hand = hand_and_face
+ include_face = hand_and_face
+
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ input_image, _ = resize_image_with_pad(input_image, 0, upscale_method)
+ poses = self.detect_poses(input_image)
+
+ canvas = draw_poses(poses, input_image.shape[0], input_image.shape[1], draw_body=include_body, draw_hand=include_hand, draw_face=include_face, xinsr_stick_scaling=xinsr_stick_scaling)
+ canvas, remove_pad = resize_image_with_pad(canvas, detect_resolution, upscale_method)
+ detected_map = HWC3(remove_pad(canvas))
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ if image_and_json:
+ return (detected_map, encode_poses_as_dict(poses, input_image.shape[0], input_image.shape[1]))
+
+ return detected_map
+
+global_cached_animalpose = AnimalPoseImage()
+class AnimalposeDetector:
+ """
+ A class for detecting animal poses in images using the RTMPose AP10k model.
+
+ Attributes:
+ model_dir (str): Path to the directory where the pose models are stored.
+ """
+ def __init__(self, animal_pose_estimation):
+ self.animal_pose_estimation = animal_pose_estimation
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path, pretrained_det_model_or_path=None, det_filename="yolox_l.onnx", pose_filename="dw-ll_ucoco_384.onnx", torchscript_device="cuda"):
+ global global_cached_animalpose
+ det_model_path = custom_hf_download(pretrained_det_model_or_path, det_filename)
+ pose_model_path = custom_hf_download(pretrained_model_or_path, pose_filename)
+
+ print(f"\nAnimalPose: Using {det_filename} for bbox detection and {pose_filename} for pose estimation")
+ if global_cached_animalpose.det is None or global_cached_animalpose.det_filename != det_filename:
+ t = AnimalPoseImage(det_model_path, None, torchscript_device=torchscript_device)
+ t.pose = global_cached_animalpose.pose
+ t.pose_filename = global_cached_animalpose.pose
+ global_cached_animalpose = t
+
+ if global_cached_animalpose.pose is None or global_cached_animalpose.pose_filename != pose_filename:
+ t = AnimalPoseImage(None, pose_model_path, torchscript_device=torchscript_device)
+ t.det = global_cached_animalpose.det
+ t.det_filename = global_cached_animalpose.det_filename
+ global_cached_animalpose = t
+ return cls(global_cached_animalpose)
+
+ def __call__(self, input_image, detect_resolution=512, output_type="pil", image_and_json=False, upscale_method="INTER_CUBIC", **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+ result = self.animal_pose_estimation(input_image)
+ if result is None:
+ detected_map = np.zeros_like(input_image)
+ openpose_dict = {
+ 'version': 'ap10k',
+ 'animals': [],
+ 'canvas_height': input_image.shape[0],
+ 'canvas_width': input_image.shape[1]
+ }
+ else:
+ detected_map, openpose_dict = result
+ detected_map = remove_pad(detected_map)
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ if image_and_json:
+ return (detected_map, openpose_dict)
+
+ return detected_map
diff --git a/src/custom_controlnet_aux/dwpose/animalpose.py b/src/custom_controlnet_aux/dwpose/animalpose.py
new file mode 100644
index 0000000000000000000000000000000000000000..59bc433c7d6e5172537fdd4a8f4e9dc8ee158627
--- /dev/null
+++ b/src/custom_controlnet_aux/dwpose/animalpose.py
@@ -0,0 +1,271 @@
+import numpy as np
+import cv2
+import os
+import cv2
+from .dw_onnx.cv_ox_det import inference_detector as inference_onnx_yolox
+from .dw_onnx.cv_ox_yolo_nas import inference_detector as inference_onnx_yolo_nas
+from .dw_onnx.cv_ox_pose import inference_pose as inference_onnx_pose
+
+from .dw_torchscript.jit_det import inference_detector as inference_jit_yolox
+from .dw_torchscript.jit_pose import inference_pose as inference_jit_pose
+from typing import List, Optional
+from .types import PoseResult, BodyResult, Keypoint
+from .util import guess_onnx_input_shape_dtype, get_ort_providers, get_model_type, is_model_torchscript
+from timeit import default_timer
+import torch
+
+def drawBetweenKeypoints(pose_img, keypoints, indexes, color, scaleFactor):
+ ind0 = indexes[0] - 1
+ ind1 = indexes[1] - 1
+
+ point1 = (keypoints[ind0][0], keypoints[ind0][1])
+ point2 = (keypoints[ind1][0], keypoints[ind1][1])
+
+ thickness = int(5 // scaleFactor)
+
+
+ cv2.line(pose_img, (int(point1[0]), int(point1[1])), (int(point2[0]), int(point2[1])), color, thickness)
+
+
+def drawBetweenKeypointsList(pose_img, keypoints, keypointPairsList, colorsList, scaleFactor):
+ for ind, keypointPair in enumerate(keypointPairsList):
+ drawBetweenKeypoints(pose_img, keypoints, keypointPair, colorsList[ind], scaleFactor)
+
+def drawBetweenSetofKeypointLists(pose_img, keypoints_set, keypointPairsList, colorsList, scaleFactor):
+ for keypoints in keypoints_set:
+ drawBetweenKeypointsList(pose_img, keypoints, keypointPairsList, colorsList, scaleFactor)
+
+
+def padImg(img, size, blackBorder=True):
+ left, right, top, bottom = 0, 0, 0, 0
+
+ # pad x
+ if img.shape[1] < size[1]:
+ sidePadding = int((size[1] - img.shape[1]) // 2)
+ left = sidePadding
+ right = sidePadding
+
+ # pad extra on right if padding needed is an odd number
+ if img.shape[1] % 2 == 1:
+ right += 1
+
+ # pad y
+ if img.shape[0] < size[0]:
+ topBottomPadding = int((size[0] - img.shape[0]) // 2)
+ top = topBottomPadding
+ bottom = topBottomPadding
+
+ # pad extra on bottom if padding needed is an odd number
+ if img.shape[0] % 2 == 1:
+ bottom += 1
+
+ if blackBorder:
+ paddedImg = cv2.copyMakeBorder(src=img, top=top, bottom=bottom, left=left, right=right, borderType=cv2.BORDER_CONSTANT, value=(0,0,0))
+ else:
+ paddedImg = cv2.copyMakeBorder(src=img, top=top, bottom=bottom, left=left, right=right, borderType=cv2.BORDER_REPLICATE)
+
+ return paddedImg
+
+def smartCrop(img, size, center):
+
+ width = img.shape[1]
+ height = img.shape[0]
+ xSize = size[1]
+ ySize = size[0]
+ xCenter = center[0]
+ yCenter = center[1]
+
+ if img.shape[0] > size[0] or img.shape[1] > size[1]:
+
+
+ leftMargin = xCenter - xSize//2
+ rightMargin = xCenter + xSize//2
+ upMargin = yCenter - ySize//2
+ downMargin = yCenter + ySize//2
+
+
+ if(leftMargin < 0):
+ xCenter += (-leftMargin)
+ if(rightMargin > width):
+ xCenter -= (rightMargin - width)
+
+ if(upMargin < 0):
+ yCenter -= -upMargin
+ if(downMargin > height):
+ yCenter -= (downMargin - height)
+
+
+ img = cv2.getRectSubPix(img, size, (xCenter, yCenter))
+
+
+
+ return img
+
+
+
+def calculateScaleFactor(img, size, poseSpanX, poseSpanY):
+
+ poseSpanX = max(poseSpanX, size[0])
+
+ scaleFactorX = 1
+
+
+ if poseSpanX > size[0]:
+ scaleFactorX = size[0] / poseSpanX
+
+ scaleFactorY = 1
+ if poseSpanY > size[1]:
+ scaleFactorY = size[1] / poseSpanY
+
+ scaleFactor = min(scaleFactorX, scaleFactorY)
+
+
+ return scaleFactor
+
+
+
+def scaleImg(img, size, poseSpanX, poseSpanY, scaleFactor):
+ scaledImg = img
+
+ scaledImg = cv2.resize(img, (0, 0), fx=scaleFactor, fy=scaleFactor)
+
+ return scaledImg, scaleFactor
+
+class AnimalPoseImage:
+ def __init__(self, det_model_path: Optional[str] = None, pose_model_path: Optional[str] = None, torchscript_device="cuda"):
+ self.det_filename = det_model_path and os.path.basename(det_model_path)
+ self.pose_filename = pose_model_path and os.path.basename(pose_model_path)
+ self.det, self.pose = None, None
+ # return type: None ort cv2 torchscript
+ self.det_model_type = get_model_type("AnimalPose",self.det_filename)
+ self.pose_model_type = get_model_type("AnimalPose",self.pose_filename)
+ # Always loads to CPU to avoid building OpenCV.
+ cv2_device = 'cpu'
+ cv2_backend = cv2.dnn.DNN_BACKEND_OPENCV if cv2_device == 'cpu' else cv2.dnn.DNN_BACKEND_CUDA
+ # You need to manually build OpenCV through cmake to work with your GPU.
+ cv2_providers = cv2.dnn.DNN_TARGET_CPU if cv2_device == 'cpu' else cv2.dnn.DNN_TARGET_CUDA
+ ort_providers = get_ort_providers()
+
+ if self.det_model_type is None:
+ pass
+ elif self.det_model_type == "ort":
+ try:
+ import onnxruntime as ort
+ self.det = ort.InferenceSession(det_model_path, providers=ort_providers)
+ except:
+ print(f"Failed to load onnxruntime with {self.det.get_providers()}.\nPlease change EP_list in the config.yaml and restart ComfyUI")
+ self.det = ort.InferenceSession(det_model_path, providers=["CPUExecutionProvider"])
+ elif self.det_model_type == "cv2":
+ try:
+ self.det = cv2.dnn.readNetFromONNX(det_model_path)
+ self.det.setPreferableBackend(cv2_backend)
+ self.det.setPreferableTarget(cv2_providers)
+ except:
+ print("TopK operators may not work on your OpenCV, try use onnxruntime with CPUExecutionProvider")
+ try:
+ import onnxruntime as ort
+ self.det = ort.InferenceSession(det_model_path, providers=["CPUExecutionProvider"])
+ except:
+ print(f"Failed to load {det_model_path}, you can use other models instead")
+ else:
+ self.det = torch.jit.load(det_model_path)
+ self.det.to(torchscript_device)
+
+ if self.pose_model_type is None:
+ pass
+ elif self.pose_model_type == "ort":
+ try:
+ import onnxruntime as ort
+ self.pose = ort.InferenceSession(pose_model_path, providers=ort_providers)
+ except:
+ print(f"Failed to load onnxruntime with {self.pose.get_providers()}.\nPlease change EP_list in the config.yaml and restart ComfyUI")
+ self.pose = ort.InferenceSession(pose_model_path, providers=["CPUExecutionProvider"])
+ elif self.pose_model_type == "cv2":
+ self.pose = cv2.dnn.readNetFromONNX(pose_model_path)
+ self.pose.setPreferableBackend(cv2_backend)
+ self.pose.setPreferableTarget(cv2_providers)
+ else:
+ self.pose = torch.jit.load(pose_model_path)
+ self.pose.to(torchscript_device)
+
+ if self.pose_filename is not None:
+ self.pose_input_size, _ = guess_onnx_input_shape_dtype(self.pose_filename)
+
+ def __call__(self, oriImg):
+ detect_classes = list(range(14, 23 + 1)) #https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/coco.yaml
+
+ #Sacrifice accurate time measurement for compatibility
+ det_start = default_timer()
+ if is_model_torchscript(self.det):
+ det_result = inference_jit_yolox(self.det, oriImg, detect_classes=detect_classes)
+ else:
+ det_start = default_timer()
+ det_onnx_dtype = np.float32 if "yolox" in self.det_filename else np.uint8
+ if "yolox" in self.det_filename:
+ det_result = inference_onnx_yolox(self.det, oriImg, detect_classes=detect_classes, dtype=det_onnx_dtype)
+ else:
+ #FP16 and INT8 YOLO NAS accept uint8 input
+ det_result = inference_onnx_yolo_nas(self.det, oriImg, detect_classes=detect_classes, dtype=det_onnx_dtype)
+ print(f"AnimalPose: Bbox {((default_timer() - det_start) * 1000):.2f}ms")
+
+ if (det_result is None) or (det_result.shape[0] == 0):
+ openpose_dict = {
+ 'version': 'ap10k',
+ 'animals': [],
+ 'canvas_height': oriImg.shape[0],
+ 'canvas_width': oriImg.shape[1]
+ }
+ return np.zeros_like(oriImg), openpose_dict
+
+ pose_start = default_timer()
+ if is_model_torchscript(self.pose):
+ keypoint_sets, scores = inference_jit_pose(self.pose, det_result, oriImg, self.pose_input_size)
+ else:
+ pose_start = default_timer()
+ _, pose_onnx_dtype = guess_onnx_input_shape_dtype(self.pose_filename)
+ keypoint_sets, scores = inference_onnx_pose(self.pose, det_result, oriImg, self.pose_input_size, dtype=pose_onnx_dtype)
+ print(f"AnimalPose: Pose {((default_timer() - pose_start) * 1000):.2f}ms on {det_result.shape[0]} animals\n")
+
+ animal_kps_scores = []
+ pose_img = np.zeros((oriImg.shape[0], oriImg.shape[1], 3), dtype = np.uint8)
+ for (idx, keypoints) in enumerate(keypoint_sets):
+ # don't use keypoints that go outside the frame in calculations for the center
+ interorKeypoints = keypoints[((keypoints[:,0] > 0) & (keypoints[:,0] < oriImg.shape[1])) & ((keypoints[:,1] > 0) & (keypoints[:,1] < oriImg.shape[0]))]
+
+ xVals = interorKeypoints[:,0]
+ yVals = interorKeypoints[:,1]
+
+ minX = np.amin(xVals)
+ minY = np.amin(yVals)
+ maxX = np.amax(xVals)
+ maxY = np.amax(yVals)
+
+ poseSpanX = maxX - minX
+ poseSpanY = maxY - minY
+
+ # find mean center
+
+ xSum = np.sum(xVals)
+ ySum = np.sum(yVals)
+
+ xCenter = xSum // xVals.shape[0]
+ yCenter = ySum // yVals.shape[0]
+ center_of_keypoints = (xCenter,yCenter)
+
+ # order of the keypoints for AP10k and a standardized list of colors for limbs
+ keypointPairsList = [(1,2), (2,3), (1,3), (3,4), (4,9), (9,10), (10,11), (4,6), (6,7), (7,8), (4,5), (5,15), (15,16), (16,17), (5,12), (12,13), (13,14)]
+ colorsList = [(255,255,255), (100,255,100), (150,255,255), (100,50,255), (50,150,200), (0,255,255), (0,150,0), (0,0,255), (0,0,150), (255,50,255), (255,0,255), (255,0,0), (150,0,0), (255,255,100), (0,150,0), (255,255,0), (150,150,150)] # 16 colors needed
+
+ drawBetweenKeypointsList(pose_img, keypoints, keypointPairsList, colorsList, scaleFactor=1.0)
+ score = scores[idx, ..., None]
+ score[score > 1.0] = 1.0
+ score[score < 0.0] = 0.0
+ animal_kps_scores.append(np.concatenate((keypoints, score), axis=-1))
+
+ openpose_dict = {
+ 'version': 'ap10k',
+ 'animals': [keypoints.tolist() for keypoints in animal_kps_scores],
+ 'canvas_height': oriImg.shape[0],
+ 'canvas_width': oriImg.shape[1]
+ }
+ return pose_img, openpose_dict
diff --git a/src/custom_controlnet_aux/dwpose/body.py b/src/custom_controlnet_aux/dwpose/body.py
new file mode 100644
index 0000000000000000000000000000000000000000..32934f19eba4b7e762678fd1fcd6b2bd193811d6
--- /dev/null
+++ b/src/custom_controlnet_aux/dwpose/body.py
@@ -0,0 +1,261 @@
+import cv2
+import numpy as np
+import math
+import time
+from scipy.ndimage.filters import gaussian_filter
+import matplotlib.pyplot as plt
+import matplotlib
+import torch
+from torchvision import transforms
+from typing import NamedTuple, List, Union
+
+from . import util
+from .model import bodypose_model
+from .types import Keypoint, BodyResult
+
+class Body(object):
+ def __init__(self, model_path):
+ self.model = bodypose_model()
+ # if torch.cuda.is_available():
+ # self.model = self.model.cuda()
+ # print('cuda')
+ model_dict = util.transfer(self.model, torch.load(model_path))
+ self.model.load_state_dict(model_dict)
+ self.model.eval()
+
+ def __call__(self, oriImg):
+ # scale_search = [0.5, 1.0, 1.5, 2.0]
+ scale_search = [0.5]
+ boxsize = 368
+ stride = 8
+ padValue = 128
+ thre1 = 0.1
+ thre2 = 0.05
+ multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
+ heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
+ paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
+
+ for m in range(len(multiplier)):
+ scale = multiplier[m]
+ imageToTest = util.smart_resize_k(oriImg, fx=scale, fy=scale)
+ imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
+ im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
+ im = np.ascontiguousarray(im)
+
+ data = torch.from_numpy(im).float()
+ if torch.cuda.is_available():
+ data = data.cuda()
+ # data = data.permute([2, 0, 1]).unsqueeze(0).float()
+ with torch.no_grad():
+ data = data.to(self.cn_device)
+ Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
+ Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
+ Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()
+
+ # extract outputs, resize, and remove padding
+ # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0)) # output 1 is heatmaps
+ heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0)) # output 1 is heatmaps
+ heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride)
+ heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+ heatmap = util.smart_resize(heatmap, (oriImg.shape[0], oriImg.shape[1]))
+
+ # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs
+ paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0)) # output 0 is PAFs
+ paf = util.smart_resize_k(paf, fx=stride, fy=stride)
+ paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+ paf = util.smart_resize(paf, (oriImg.shape[0], oriImg.shape[1]))
+
+ heatmap_avg += heatmap_avg + heatmap / len(multiplier)
+ paf_avg += + paf / len(multiplier)
+
+ all_peaks = []
+ peak_counter = 0
+
+ for part in range(18):
+ map_ori = heatmap_avg[:, :, part]
+ one_heatmap = gaussian_filter(map_ori, sigma=3)
+
+ map_left = np.zeros(one_heatmap.shape)
+ map_left[1:, :] = one_heatmap[:-1, :]
+ map_right = np.zeros(one_heatmap.shape)
+ map_right[:-1, :] = one_heatmap[1:, :]
+ map_up = np.zeros(one_heatmap.shape)
+ map_up[:, 1:] = one_heatmap[:, :-1]
+ map_down = np.zeros(one_heatmap.shape)
+ map_down[:, :-1] = one_heatmap[:, 1:]
+
+ peaks_binary = np.logical_and.reduce(
+ (one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1))
+ peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])) # note reverse
+ peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
+ peak_id = range(peak_counter, peak_counter + len(peaks))
+ peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))]
+
+ all_peaks.append(peaks_with_score_and_id)
+ peak_counter += len(peaks)
+
+ # find connection in the specified sequence, center 29 is in the position 15
+ limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
+ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
+ [1, 16], [16, 18], [3, 17], [6, 18]]
+ # the middle joints heatmap correpondence
+ mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \
+ [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \
+ [55, 56], [37, 38], [45, 46]]
+
+ connection_all = []
+ special_k = []
+ mid_num = 10
+
+ for k in range(len(mapIdx)):
+ score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
+ candA = all_peaks[limbSeq[k][0] - 1]
+ candB = all_peaks[limbSeq[k][1] - 1]
+ nA = len(candA)
+ nB = len(candB)
+ indexA, indexB = limbSeq[k]
+ if (nA != 0 and nB != 0):
+ connection_candidate = []
+ for i in range(nA):
+ for j in range(nB):
+ vec = np.subtract(candB[j][:2], candA[i][:2])
+ norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
+ norm = max(0.001, norm)
+ vec = np.divide(vec, norm)
+
+ startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \
+ np.linspace(candA[i][1], candB[j][1], num=mid_num)))
+
+ vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \
+ for I in range(len(startend))])
+ vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \
+ for I in range(len(startend))])
+
+ score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
+ score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(
+ 0.5 * oriImg.shape[0] / norm - 1, 0)
+ criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
+ criterion2 = score_with_dist_prior > 0
+ if criterion1 and criterion2:
+ connection_candidate.append(
+ [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]])
+
+ connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
+ connection = np.zeros((0, 5))
+ for c in range(len(connection_candidate)):
+ i, j, s = connection_candidate[c][0:3]
+ if (i not in connection[:, 3] and j not in connection[:, 4]):
+ connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
+ if (len(connection) >= min(nA, nB)):
+ break
+
+ connection_all.append(connection)
+ else:
+ special_k.append(k)
+ connection_all.append([])
+
+ # last number in each row is the total parts number of that person
+ # the second last number in each row is the score of the overall configuration
+ subset = -1 * np.ones((0, 20))
+ candidate = np.array([item for sublist in all_peaks for item in sublist])
+
+ for k in range(len(mapIdx)):
+ if k not in special_k:
+ partAs = connection_all[k][:, 0]
+ partBs = connection_all[k][:, 1]
+ indexA, indexB = np.array(limbSeq[k]) - 1
+
+ for i in range(len(connection_all[k])): # = 1:size(temp,1)
+ found = 0
+ subset_idx = [-1, -1]
+ for j in range(len(subset)): # 1:size(subset,1):
+ if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
+ subset_idx[found] = j
+ found += 1
+
+ if found == 1:
+ j = subset_idx[0]
+ if subset[j][indexB] != partBs[i]:
+ subset[j][indexB] = partBs[i]
+ subset[j][-1] += 1
+ subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
+ elif found == 2: # if found 2 and disjoint, merge them
+ j1, j2 = subset_idx
+ membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
+ if len(np.nonzero(membership == 2)[0]) == 0: # merge
+ subset[j1][:-2] += (subset[j2][:-2] + 1)
+ subset[j1][-2:] += subset[j2][-2:]
+ subset[j1][-2] += connection_all[k][i][2]
+ subset = np.delete(subset, j2, 0)
+ else: # as like found == 1
+ subset[j1][indexB] = partBs[i]
+ subset[j1][-1] += 1
+ subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
+
+ # if find no partA in the subset, create a new subset
+ elif not found and k < 17:
+ row = -1 * np.ones(20)
+ row[indexA] = partAs[i]
+ row[indexB] = partBs[i]
+ row[-1] = 2
+ row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
+ subset = np.vstack([subset, row])
+ # delete some rows of subset which has few parts occur
+ deleteIdx = []
+ for i in range(len(subset)):
+ if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
+ deleteIdx.append(i)
+ subset = np.delete(subset, deleteIdx, axis=0)
+
+ # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
+ # candidate: x, y, score, id
+ return candidate, subset
+
+ @staticmethod
+ def format_body_result(candidate: np.ndarray, subset: np.ndarray) -> List[BodyResult]:
+ """
+ Format the body results from the candidate and subset arrays into a list of BodyResult objects.
+
+ Args:
+ candidate (np.ndarray): An array of candidates containing the x, y coordinates, score, and id
+ for each body part.
+ subset (np.ndarray): An array of subsets containing indices to the candidate array for each
+ person detected. The last two columns of each row hold the total score and total parts
+ of the person.
+
+ Returns:
+ List[BodyResult]: A list of BodyResult objects, where each object represents a person with
+ detected keypoints, total score, and total parts.
+ """
+ return [
+ BodyResult(
+ keypoints=[
+ Keypoint(
+ x=candidate[candidate_index][0],
+ y=candidate[candidate_index][1],
+ score=candidate[candidate_index][2],
+ id=candidate[candidate_index][3]
+ ) if candidate_index != -1 else None
+ for candidate_index in person[:18].astype(int)
+ ],
+ total_score=person[18],
+ total_parts=person[19]
+ )
+ for person in subset
+ ]
+
+
+if __name__ == "__main__":
+ body_estimation = Body('../model/body_pose_model.pth')
+
+ test_image = '../images/ski.jpg'
+ oriImg = cv2.imread(test_image) # B,G,R order
+ candidate, subset = body_estimation(oriImg)
+ bodies = body_estimation.format_body_result(candidate, subset)
+
+ canvas = oriImg
+ for body in bodies:
+ canvas = util.draw_bodypose(canvas, body)
+
+ plt.imshow(canvas[:, :, [2, 1, 0]])
+ plt.show()
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/dwpose/dw_onnx/__init__.py b/src/custom_controlnet_aux/dwpose/dw_onnx/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..33e7a7f594ef441479257c788e4c0d6e08657fc8
--- /dev/null
+++ b/src/custom_controlnet_aux/dwpose/dw_onnx/__init__.py
@@ -0,0 +1 @@
+#Dummy file ensuring this package will be recognized
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/dwpose/dw_onnx/cv_ox_det.py b/src/custom_controlnet_aux/dwpose/dw_onnx/cv_ox_det.py
new file mode 100644
index 0000000000000000000000000000000000000000..0365234c2caef3b98fc01304ba5365da2115ba65
--- /dev/null
+++ b/src/custom_controlnet_aux/dwpose/dw_onnx/cv_ox_det.py
@@ -0,0 +1,129 @@
+import cv2
+import numpy as np
+
+def nms(boxes, scores, nms_thr):
+ """Single class NMS implemented in Numpy."""
+ x1 = boxes[:, 0]
+ y1 = boxes[:, 1]
+ x2 = boxes[:, 2]
+ y2 = boxes[:, 3]
+
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+ order = scores.argsort()[::-1]
+
+ keep = []
+ while order.size > 0:
+ i = order[0]
+ keep.append(i)
+ xx1 = np.maximum(x1[i], x1[order[1:]])
+ yy1 = np.maximum(y1[i], y1[order[1:]])
+ xx2 = np.minimum(x2[i], x2[order[1:]])
+ yy2 = np.minimum(y2[i], y2[order[1:]])
+
+ w = np.maximum(0.0, xx2 - xx1 + 1)
+ h = np.maximum(0.0, yy2 - yy1 + 1)
+ inter = w * h
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+ inds = np.where(ovr <= nms_thr)[0]
+ order = order[inds + 1]
+
+ return keep
+
+def multiclass_nms(boxes, scores, nms_thr, score_thr):
+ """Multiclass NMS implemented in Numpy. Class-aware version."""
+ final_dets = []
+ num_classes = scores.shape[1]
+ for cls_ind in range(num_classes):
+ cls_scores = scores[:, cls_ind]
+ valid_score_mask = cls_scores > score_thr
+ if valid_score_mask.sum() == 0:
+ continue
+ else:
+ valid_scores = cls_scores[valid_score_mask]
+ valid_boxes = boxes[valid_score_mask]
+ keep = nms(valid_boxes, valid_scores, nms_thr)
+ if len(keep) > 0:
+ cls_inds = np.ones((len(keep), 1)) * cls_ind
+ dets = np.concatenate(
+ [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
+ )
+ final_dets.append(dets)
+ if len(final_dets) == 0:
+ return None
+ return np.concatenate(final_dets, 0)
+
+def demo_postprocess(outputs, img_size, p6=False):
+ grids = []
+ expanded_strides = []
+ strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
+
+ hsizes = [img_size[0] // stride for stride in strides]
+ wsizes = [img_size[1] // stride for stride in strides]
+
+ for hsize, wsize, stride in zip(hsizes, wsizes, strides):
+ xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
+ grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
+ grids.append(grid)
+ shape = grid.shape[:2]
+ expanded_strides.append(np.full((*shape, 1), stride))
+
+ grids = np.concatenate(grids, 1)
+ expanded_strides = np.concatenate(expanded_strides, 1)
+ outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
+ outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
+
+ return outputs
+
+def preprocess(img, input_size, swap=(2, 0, 1)):
+ if len(img.shape) == 3:
+ padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
+ else:
+ padded_img = np.ones(input_size, dtype=np.uint8) * 114
+
+ r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
+ resized_img = cv2.resize(
+ img,
+ (int(img.shape[1] * r), int(img.shape[0] * r)),
+ interpolation=cv2.INTER_LINEAR,
+ ).astype(np.uint8)
+ padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
+
+ padded_img = padded_img.transpose(swap)
+ padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
+ return padded_img, r
+
+def inference_detector(session, oriImg, detect_classes=[0], dtype=np.float32):
+ input_shape = (640,640)
+ img, ratio = preprocess(oriImg, input_shape)
+
+ input = img[None, :, :, :]
+ input = input.astype(dtype)
+ if "InferenceSession" in type(session).__name__:
+ input_name = session.get_inputs()[0].name
+ output = session.run(None, {input_name: input})
+ else:
+ outNames = session.getUnconnectedOutLayersNames()
+ session.setInput(input)
+ output = session.forward(outNames)
+
+ predictions = demo_postprocess(output[0], input_shape)[0]
+
+ boxes = predictions[:, :4]
+ scores = predictions[:, 4:5] * predictions[:, 5:]
+
+ boxes_xyxy = np.ones_like(boxes)
+ boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
+ boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
+ boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
+ boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
+ boxes_xyxy /= ratio
+ dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
+ if dets is None:
+ return None
+ final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
+ isscore = final_scores>0.3
+ iscat = np.isin(final_cls_inds, detect_classes)
+ isbbox = [ i and j for (i, j) in zip(isscore, iscat)]
+ final_boxes = final_boxes[isbbox]
+ return final_boxes
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/dwpose/dw_onnx/cv_ox_pose.py b/src/custom_controlnet_aux/dwpose/dw_onnx/cv_ox_pose.py
new file mode 100644
index 0000000000000000000000000000000000000000..956c4bc715214bcc2e6228166032418294df46bc
--- /dev/null
+++ b/src/custom_controlnet_aux/dwpose/dw_onnx/cv_ox_pose.py
@@ -0,0 +1,363 @@
+from typing import List, Tuple
+
+import cv2
+import numpy as np
+
+def preprocess(
+ img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """Do preprocessing for DWPose model inference.
+
+ Args:
+ img (np.ndarray): Input image in shape.
+ input_size (tuple): Input image size in shape (w, h).
+
+ Returns:
+ tuple:
+ - resized_img (np.ndarray): Preprocessed image.
+ - center (np.ndarray): Center of image.
+ - scale (np.ndarray): Scale of image.
+ """
+ # get shape of image
+ img_shape = img.shape[:2]
+ out_img, out_center, out_scale = [], [], []
+ if len(out_bbox) == 0:
+ out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
+ for i in range(len(out_bbox)):
+ x0 = out_bbox[i][0]
+ y0 = out_bbox[i][1]
+ x1 = out_bbox[i][2]
+ y1 = out_bbox[i][3]
+ bbox = np.array([x0, y0, x1, y1])
+
+ # get center and scale
+ center, scale = bbox_xyxy2cs(bbox, padding=1.25)
+
+ # do affine transformation
+ resized_img, scale = top_down_affine(input_size, scale, center, img)
+
+ # normalize image
+ mean = np.array([123.675, 116.28, 103.53])
+ std = np.array([58.395, 57.12, 57.375])
+ resized_img = (resized_img - mean) / std
+
+ out_img.append(resized_img)
+ out_center.append(center)
+ out_scale.append(scale)
+
+ return out_img, out_center, out_scale
+
+
+def inference(sess, img, dtype=np.float32):
+ """Inference DWPose model. Processing all image segments at once to take advantage of GPU's parallelism ability if onnxruntime is installed
+
+ Args:
+ sess : ONNXRuntime session.
+ img : Input image in shape.
+
+ Returns:
+ outputs : Output of DWPose model.
+ """
+ all_out = []
+ # build input
+ input = np.stack(img, axis=0).transpose(0, 3, 1, 2)
+ input = input.astype(dtype)
+ if "InferenceSession" in type(sess).__name__:
+ input_name = sess.get_inputs()[0].name
+ all_outputs = sess.run(None, {input_name: input})
+ for batch_idx in range(len(all_outputs[0])):
+ outputs = [all_outputs[i][batch_idx:batch_idx+1,...] for i in range(len(all_outputs))]
+ all_out.append(outputs)
+ return all_out
+
+ #OpenCV doesn't support batch processing sadly
+ for i in range(len(img)):
+ input = img[i].transpose(2, 0, 1)
+ input = input[None, :, :, :]
+
+ outNames = sess.getUnconnectedOutLayersNames()
+ sess.setInput(input)
+ outputs = sess.forward(outNames)
+ all_out.append(outputs)
+
+ return all_out
+
+def postprocess(outputs: List[np.ndarray],
+ model_input_size: Tuple[int, int],
+ center: Tuple[int, int],
+ scale: Tuple[int, int],
+ simcc_split_ratio: float = 2.0
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """Postprocess for DWPose model output.
+
+ Args:
+ outputs (np.ndarray): Output of RTMPose model.
+ model_input_size (tuple): RTMPose model Input image size.
+ center (tuple): Center of bbox in shape (x, y).
+ scale (tuple): Scale of bbox in shape (w, h).
+ simcc_split_ratio (float): Split ratio of simcc.
+
+ Returns:
+ tuple:
+ - keypoints (np.ndarray): Rescaled keypoints.
+ - scores (np.ndarray): Model predict scores.
+ """
+ all_key = []
+ all_score = []
+ for i in range(len(outputs)):
+ # use simcc to decode
+ simcc_x, simcc_y = outputs[i]
+ keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
+
+ # rescale keypoints
+ keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
+ all_key.append(keypoints[0])
+ all_score.append(scores[0])
+
+ return np.array(all_key), np.array(all_score)
+
+
+def bbox_xyxy2cs(bbox: np.ndarray,
+ padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
+ """Transform the bbox format from (x,y,w,h) into (center, scale)
+
+ Args:
+ bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
+ as (left, top, right, bottom)
+ padding (float): BBox padding factor that will be multilied to scale.
+ Default: 1.0
+
+ Returns:
+ tuple: A tuple containing center and scale.
+ - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
+ (n, 2)
+ - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
+ (n, 2)
+ """
+ # convert single bbox from (4, ) to (1, 4)
+ dim = bbox.ndim
+ if dim == 1:
+ bbox = bbox[None, :]
+
+ # get bbox center and scale
+ x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
+ center = np.hstack([x1 + x2, y1 + y2]) * 0.5
+ scale = np.hstack([x2 - x1, y2 - y1]) * padding
+
+ if dim == 1:
+ center = center[0]
+ scale = scale[0]
+
+ return center, scale
+
+
+def _fix_aspect_ratio(bbox_scale: np.ndarray,
+ aspect_ratio: float) -> np.ndarray:
+ """Extend the scale to match the given aspect ratio.
+
+ Args:
+ scale (np.ndarray): The image scale (w, h) in shape (2, )
+ aspect_ratio (float): The ratio of ``w/h``
+
+ Returns:
+ np.ndarray: The reshaped image scale in (2, )
+ """
+ w, h = np.hsplit(bbox_scale, [1])
+ bbox_scale = np.where(w > h * aspect_ratio,
+ np.hstack([w, w / aspect_ratio]),
+ np.hstack([h * aspect_ratio, h]))
+ return bbox_scale
+
+
+def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
+ """Rotate a point by an angle.
+
+ Args:
+ pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
+ angle_rad (float): rotation angle in radian
+
+ Returns:
+ np.ndarray: Rotated point in shape (2, )
+ """
+ sn, cs = np.sin(angle_rad), np.cos(angle_rad)
+ rot_mat = np.array([[cs, -sn], [sn, cs]])
+ return rot_mat @ pt
+
+
+def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+ """To calculate the affine matrix, three pairs of points are required. This
+ function is used to get the 3rd point, given 2D points a & b.
+
+ The 3rd point is defined by rotating vector `a - b` by 90 degrees
+ anticlockwise, using b as the rotation center.
+
+ Args:
+ a (np.ndarray): The 1st point (x,y) in shape (2, )
+ b (np.ndarray): The 2nd point (x,y) in shape (2, )
+
+ Returns:
+ np.ndarray: The 3rd point.
+ """
+ direction = a - b
+ c = b + np.r_[-direction[1], direction[0]]
+ return c
+
+
+def get_warp_matrix(center: np.ndarray,
+ scale: np.ndarray,
+ rot: float,
+ output_size: Tuple[int, int],
+ shift: Tuple[float, float] = (0., 0.),
+ inv: bool = False) -> np.ndarray:
+ """Calculate the affine transformation matrix that can warp the bbox area
+ in the input image to the output size.
+
+ Args:
+ center (np.ndarray[2, ]): Center of the bounding box (x, y).
+ scale (np.ndarray[2, ]): Scale of the bounding box
+ wrt [width, height].
+ rot (float): Rotation angle (degree).
+ output_size (np.ndarray[2, ] | list(2,)): Size of the
+ destination heatmaps.
+ shift (0-100%): Shift translation ratio wrt the width/height.
+ Default (0., 0.).
+ inv (bool): Option to inverse the affine transform direction.
+ (inv=False: src->dst or inv=True: dst->src)
+
+ Returns:
+ np.ndarray: A 2x3 transformation matrix
+ """
+ shift = np.array(shift)
+ src_w = scale[0]
+ dst_w = output_size[0]
+ dst_h = output_size[1]
+
+ # compute transformation matrix
+ rot_rad = np.deg2rad(rot)
+ src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
+ dst_dir = np.array([0., dst_w * -0.5])
+
+ # get four corners of the src rectangle in the original image
+ src = np.zeros((3, 2), dtype=np.float32)
+ src[0, :] = center + scale * shift
+ src[1, :] = center + src_dir + scale * shift
+ src[2, :] = _get_3rd_point(src[0, :], src[1, :])
+
+ # get four corners of the dst rectangle in the input image
+ dst = np.zeros((3, 2), dtype=np.float32)
+ dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
+ dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
+ dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
+
+ if inv:
+ warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
+ else:
+ warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
+
+ return warp_mat
+
+
+def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict,
+ img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """Get the bbox image as the model input by affine transform.
+
+ Args:
+ input_size (dict): The input size of the model.
+ bbox_scale (dict): The bbox scale of the img.
+ bbox_center (dict): The bbox center of the img.
+ img (np.ndarray): The original image.
+
+ Returns:
+ tuple: A tuple containing center and scale.
+ - np.ndarray[float32]: img after affine transform.
+ - np.ndarray[float32]: bbox scale after affine transform.
+ """
+ w, h = input_size
+ warp_size = (int(w), int(h))
+
+ # reshape bbox to fixed aspect ratio
+ bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
+
+ # get the affine matrix
+ center = bbox_center
+ scale = bbox_scale
+ rot = 0
+ warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
+
+ # do affine transform
+ img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
+
+ return img, bbox_scale
+
+
+def get_simcc_maximum(simcc_x: np.ndarray,
+ simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """Get maximum response location and value from simcc representations.
+
+ Note:
+ instance number: N
+ num_keypoints: K
+ heatmap height: H
+ heatmap width: W
+
+ Args:
+ simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
+ simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
+
+ Returns:
+ tuple:
+ - locs (np.ndarray): locations of maximum heatmap responses in shape
+ (K, 2) or (N, K, 2)
+ - vals (np.ndarray): values of maximum heatmap responses in shape
+ (K,) or (N, K)
+ """
+ N, K, Wx = simcc_x.shape
+ simcc_x = simcc_x.reshape(N * K, -1)
+ simcc_y = simcc_y.reshape(N * K, -1)
+
+ # get maximum value locations
+ x_locs = np.argmax(simcc_x, axis=1)
+ y_locs = np.argmax(simcc_y, axis=1)
+ locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
+ max_val_x = np.amax(simcc_x, axis=1)
+ max_val_y = np.amax(simcc_y, axis=1)
+
+ # get maximum value across x and y axis
+ mask = max_val_x > max_val_y
+ max_val_x[mask] = max_val_y[mask]
+ vals = max_val_x
+ locs[vals <= 0.] = -1
+
+ # reshape
+ locs = locs.reshape(N, K, 2)
+ vals = vals.reshape(N, K)
+
+ return locs, vals
+
+
+def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
+ simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
+ """Modulate simcc distribution with Gaussian.
+
+ Args:
+ simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
+ simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
+ simcc_split_ratio (int): The split ratio of simcc.
+
+ Returns:
+ tuple: A tuple containing center and scale.
+ - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
+ - np.ndarray[float32]: scores in shape (K,) or (n, K)
+ """
+ keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
+ keypoints /= simcc_split_ratio
+
+ return keypoints, scores
+
+
+def inference_pose(session, out_bbox, oriImg, model_input_size=(288, 384), dtype=np.float32):
+ resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
+ outputs = inference(session, resized_img, dtype)
+ keypoints, scores = postprocess(outputs, model_input_size, center, scale)
+
+ return keypoints, scores
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/dwpose/dw_onnx/cv_ox_yolo_nas.py b/src/custom_controlnet_aux/dwpose/dw_onnx/cv_ox_yolo_nas.py
new file mode 100644
index 0000000000000000000000000000000000000000..67ff249be283b11e0eb7d95ef7c0adc024c48285
--- /dev/null
+++ b/src/custom_controlnet_aux/dwpose/dw_onnx/cv_ox_yolo_nas.py
@@ -0,0 +1,60 @@
+# Source: https://github.com/Hyuto/yolo-nas-onnx/tree/master/yolo-nas-py
+# Inspired from: https://github.com/Deci-AI/super-gradients/blob/3.1.1/src/super_gradients/training/processing/processing.py
+
+import numpy as np
+import cv2
+
+def preprocess(img, input_size, swap=(2, 0, 1)):
+ if len(img.shape) == 3:
+ padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
+ else:
+ padded_img = np.ones(input_size, dtype=np.uint8) * 114
+
+ r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
+ resized_img = cv2.resize(
+ img,
+ (int(img.shape[1] * r), int(img.shape[0] * r)),
+ interpolation=cv2.INTER_LINEAR,
+ ).astype(np.uint8)
+ padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
+
+ padded_img = padded_img.transpose(swap)
+ padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
+ return padded_img, r
+
+def inference_detector(session, oriImg, detect_classes=[0], dtype=np.uint8):
+ """
+ This function is only compatible with onnx models exported from the new API with built-in NMS
+ ```py
+ from super_gradients.conversion.conversion_enums import ExportQuantizationMode
+ from super_gradients.common.object_names import Models
+ from super_gradients.training import models
+
+ model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco")
+
+ export_result = model.export(
+ "yolo_nas/yolo_nas_l_fp16.onnx",
+ quantization_mode=ExportQuantizationMode.FP16,
+ device="cuda"
+ )
+ ```
+ """
+ input_shape = (640,640)
+ img, ratio = preprocess(oriImg, input_shape)
+ input = img[None, :, :, :]
+ input = input.astype(dtype)
+ if "InferenceSession" in type(session).__name__:
+ input_name = session.get_inputs()[0].name
+ output = session.run(None, {input_name: input})
+ else:
+ outNames = session.getUnconnectedOutLayersNames()
+ session.setInput(input)
+ output = session.forward(outNames)
+ num_preds, pred_boxes, pred_scores, pred_classes = output
+ num_preds = num_preds[0,0]
+ if num_preds == 0:
+ return None
+ idxs = np.where((np.isin(pred_classes[0, :num_preds], detect_classes)) & (pred_scores[0, :num_preds] > 0.3))
+ if (len(idxs) == 0) or (idxs[0].size == 0):
+ return None
+ return pred_boxes[0, idxs].squeeze(axis=0) / ratio
diff --git a/src/custom_controlnet_aux/dwpose/dw_torchscript/__init__.py b/src/custom_controlnet_aux/dwpose/dw_torchscript/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..33e7a7f594ef441479257c788e4c0d6e08657fc8
--- /dev/null
+++ b/src/custom_controlnet_aux/dwpose/dw_torchscript/__init__.py
@@ -0,0 +1 @@
+#Dummy file ensuring this package will be recognized
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/dwpose/dw_torchscript/jit_det.py b/src/custom_controlnet_aux/dwpose/dw_torchscript/jit_det.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ae863509df37386365791df8b4a7635c05d9344
--- /dev/null
+++ b/src/custom_controlnet_aux/dwpose/dw_torchscript/jit_det.py
@@ -0,0 +1,125 @@
+import cv2
+import numpy as np
+import torch
+
+def nms(boxes, scores, nms_thr):
+ """Single class NMS implemented in Numpy."""
+ x1 = boxes[:, 0]
+ y1 = boxes[:, 1]
+ x2 = boxes[:, 2]
+ y2 = boxes[:, 3]
+
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+ order = scores.argsort()[::-1]
+
+ keep = []
+ while order.size > 0:
+ i = order[0]
+ keep.append(i)
+ xx1 = np.maximum(x1[i], x1[order[1:]])
+ yy1 = np.maximum(y1[i], y1[order[1:]])
+ xx2 = np.minimum(x2[i], x2[order[1:]])
+ yy2 = np.minimum(y2[i], y2[order[1:]])
+
+ w = np.maximum(0.0, xx2 - xx1 + 1)
+ h = np.maximum(0.0, yy2 - yy1 + 1)
+ inter = w * h
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+ inds = np.where(ovr <= nms_thr)[0]
+ order = order[inds + 1]
+
+ return keep
+
+def multiclass_nms(boxes, scores, nms_thr, score_thr):
+ """Multiclass NMS implemented in Numpy. Class-aware version."""
+ final_dets = []
+ num_classes = scores.shape[1]
+ for cls_ind in range(num_classes):
+ cls_scores = scores[:, cls_ind]
+ valid_score_mask = cls_scores > score_thr
+ if valid_score_mask.sum() == 0:
+ continue
+ else:
+ valid_scores = cls_scores[valid_score_mask]
+ valid_boxes = boxes[valid_score_mask]
+ keep = nms(valid_boxes, valid_scores, nms_thr)
+ if len(keep) > 0:
+ cls_inds = np.ones((len(keep), 1)) * cls_ind
+ dets = np.concatenate(
+ [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
+ )
+ final_dets.append(dets)
+ if len(final_dets) == 0:
+ return None
+ return np.concatenate(final_dets, 0)
+
+def demo_postprocess(outputs, img_size, p6=False):
+ grids = []
+ expanded_strides = []
+ strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
+
+ hsizes = [img_size[0] // stride for stride in strides]
+ wsizes = [img_size[1] // stride for stride in strides]
+
+ for hsize, wsize, stride in zip(hsizes, wsizes, strides):
+ xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
+ grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
+ grids.append(grid)
+ shape = grid.shape[:2]
+ expanded_strides.append(np.full((*shape, 1), stride))
+
+ grids = np.concatenate(grids, 1)
+ expanded_strides = np.concatenate(expanded_strides, 1)
+ outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
+ outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
+
+ return outputs
+
+def preprocess(img, input_size, swap=(2, 0, 1)):
+ if len(img.shape) == 3:
+ padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
+ else:
+ padded_img = np.ones(input_size, dtype=np.uint8) * 114
+
+ r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
+ resized_img = cv2.resize(
+ img,
+ (int(img.shape[1] * r), int(img.shape[0] * r)),
+ interpolation=cv2.INTER_LINEAR,
+ ).astype(np.uint8)
+ padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
+
+ padded_img = padded_img.transpose(swap)
+ padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
+ return padded_img, r
+
+def inference_detector(model, oriImg, detect_classes=[0]):
+ input_shape = (640,640)
+ img, ratio = preprocess(oriImg, input_shape)
+
+ device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
+ input = img[None, :, :, :]
+ input = torch.from_numpy(input).to(device, dtype)
+
+ output = model(input).float().cpu().detach().numpy()
+ predictions = demo_postprocess(output[0], input_shape)
+
+ boxes = predictions[:, :4]
+ scores = predictions[:, 4:5] * predictions[:, 5:]
+
+ boxes_xyxy = np.ones_like(boxes)
+ boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
+ boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
+ boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
+ boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
+ boxes_xyxy /= ratio
+ dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
+ if dets is None:
+ return None
+ final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
+ isscore = final_scores>0.3
+ iscat = np.isin(final_cls_inds, detect_classes)
+ isbbox = [ i and j for (i, j) in zip(isscore, iscat)]
+ final_boxes = final_boxes[isbbox]
+ return final_boxes
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/dwpose/dw_torchscript/jit_pose.py b/src/custom_controlnet_aux/dwpose/dw_torchscript/jit_pose.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a7b6acee1cd39657cec719aef1aaa09648da4e2
--- /dev/null
+++ b/src/custom_controlnet_aux/dwpose/dw_torchscript/jit_pose.py
@@ -0,0 +1,363 @@
+from typing import List, Tuple
+
+import cv2
+import numpy as np
+import torch
+
+def preprocess(
+ img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """Do preprocessing for DWPose model inference.
+
+ Args:
+ img (np.ndarray): Input image in shape.
+ input_size (tuple): Input image size in shape (w, h).
+
+ Returns:
+ tuple:
+ - resized_img (np.ndarray): Preprocessed image.
+ - center (np.ndarray): Center of image.
+ - scale (np.ndarray): Scale of image.
+ """
+ # get shape of image
+ img_shape = img.shape[:2]
+ out_img, out_center, out_scale = [], [], []
+ if len(out_bbox) == 0:
+ out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
+ for i in range(len(out_bbox)):
+ x0 = out_bbox[i][0]
+ y0 = out_bbox[i][1]
+ x1 = out_bbox[i][2]
+ y1 = out_bbox[i][3]
+ bbox = np.array([x0, y0, x1, y1])
+
+ # get center and scale
+ center, scale = bbox_xyxy2cs(bbox, padding=1.25)
+
+ # do affine transformation
+ resized_img, scale = top_down_affine(input_size, scale, center, img)
+
+ # normalize image
+ mean = np.array([123.675, 116.28, 103.53])
+ std = np.array([58.395, 57.12, 57.375])
+ resized_img = (resized_img - mean) / std
+
+ out_img.append(resized_img)
+ out_center.append(center)
+ out_scale.append(scale)
+
+ return out_img, out_center, out_scale
+
+def inference(model, img, bs=5):
+ """Inference DWPose model implemented in TorchScript.
+
+ Args:
+ model : TorchScript Model.
+ img : Input image in shape.
+
+ Returns:
+ outputs : Output of DWPose model.
+ """
+ all_out = []
+ # build input
+ orig_img_count = len(img)
+ #Pad zeros to fit batch size
+ for _ in range(bs - (orig_img_count % bs)):
+ img.append(np.zeros_like(img[0]))
+ input = np.stack(img, axis=0).transpose(0, 3, 1, 2)
+ device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
+ input = torch.from_numpy(input).to(device, dtype)
+
+ out1, out2 = [], []
+ for i in range(input.shape[0] // bs):
+ curr_batch_output = model(input[i*bs:(i+1)*bs])
+ out1.append(curr_batch_output[0].float())
+ out2.append(curr_batch_output[1].float())
+ out1, out2 = torch.cat(out1, dim=0)[:orig_img_count], torch.cat(out2, dim=0)[:orig_img_count]
+ out1, out2 = out1.float().cpu().detach().numpy(), out2.float().cpu().detach().numpy()
+ all_outputs = out1, out2
+
+ for batch_idx in range(len(all_outputs[0])):
+ outputs = [all_outputs[i][batch_idx:batch_idx+1,...] for i in range(len(all_outputs))]
+ all_out.append(outputs)
+ return all_out
+def postprocess(outputs: List[np.ndarray],
+ model_input_size: Tuple[int, int],
+ center: Tuple[int, int],
+ scale: Tuple[int, int],
+ simcc_split_ratio: float = 2.0
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """Postprocess for DWPose model output.
+
+ Args:
+ outputs (np.ndarray): Output of RTMPose model.
+ model_input_size (tuple): RTMPose model Input image size.
+ center (tuple): Center of bbox in shape (x, y).
+ scale (tuple): Scale of bbox in shape (w, h).
+ simcc_split_ratio (float): Split ratio of simcc.
+
+ Returns:
+ tuple:
+ - keypoints (np.ndarray): Rescaled keypoints.
+ - scores (np.ndarray): Model predict scores.
+ """
+ all_key = []
+ all_score = []
+ for i in range(len(outputs)):
+ # use simcc to decode
+ simcc_x, simcc_y = outputs[i]
+ keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
+
+ # rescale keypoints
+ keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
+ all_key.append(keypoints[0])
+ all_score.append(scores[0])
+
+ return np.array(all_key), np.array(all_score)
+
+
+def bbox_xyxy2cs(bbox: np.ndarray,
+ padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
+ """Transform the bbox format from (x,y,w,h) into (center, scale)
+
+ Args:
+ bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
+ as (left, top, right, bottom)
+ padding (float): BBox padding factor that will be multilied to scale.
+ Default: 1.0
+
+ Returns:
+ tuple: A tuple containing center and scale.
+ - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
+ (n, 2)
+ - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
+ (n, 2)
+ """
+ # convert single bbox from (4, ) to (1, 4)
+ dim = bbox.ndim
+ if dim == 1:
+ bbox = bbox[None, :]
+
+ # get bbox center and scale
+ x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
+ center = np.hstack([x1 + x2, y1 + y2]) * 0.5
+ scale = np.hstack([x2 - x1, y2 - y1]) * padding
+
+ if dim == 1:
+ center = center[0]
+ scale = scale[0]
+
+ return center, scale
+
+
+def _fix_aspect_ratio(bbox_scale: np.ndarray,
+ aspect_ratio: float) -> np.ndarray:
+ """Extend the scale to match the given aspect ratio.
+
+ Args:
+ scale (np.ndarray): The image scale (w, h) in shape (2, )
+ aspect_ratio (float): The ratio of ``w/h``
+
+ Returns:
+ np.ndarray: The reshaped image scale in (2, )
+ """
+ w, h = np.hsplit(bbox_scale, [1])
+ bbox_scale = np.where(w > h * aspect_ratio,
+ np.hstack([w, w / aspect_ratio]),
+ np.hstack([h * aspect_ratio, h]))
+ return bbox_scale
+
+
+def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
+ """Rotate a point by an angle.
+
+ Args:
+ pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
+ angle_rad (float): rotation angle in radian
+
+ Returns:
+ np.ndarray: Rotated point in shape (2, )
+ """
+ sn, cs = np.sin(angle_rad), np.cos(angle_rad)
+ rot_mat = np.array([[cs, -sn], [sn, cs]])
+ return rot_mat @ pt
+
+
+def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+ """To calculate the affine matrix, three pairs of points are required. This
+ function is used to get the 3rd point, given 2D points a & b.
+
+ The 3rd point is defined by rotating vector `a - b` by 90 degrees
+ anticlockwise, using b as the rotation center.
+
+ Args:
+ a (np.ndarray): The 1st point (x,y) in shape (2, )
+ b (np.ndarray): The 2nd point (x,y) in shape (2, )
+
+ Returns:
+ np.ndarray: The 3rd point.
+ """
+ direction = a - b
+ c = b + np.r_[-direction[1], direction[0]]
+ return c
+
+
+def get_warp_matrix(center: np.ndarray,
+ scale: np.ndarray,
+ rot: float,
+ output_size: Tuple[int, int],
+ shift: Tuple[float, float] = (0., 0.),
+ inv: bool = False) -> np.ndarray:
+ """Calculate the affine transformation matrix that can warp the bbox area
+ in the input image to the output size.
+
+ Args:
+ center (np.ndarray[2, ]): Center of the bounding box (x, y).
+ scale (np.ndarray[2, ]): Scale of the bounding box
+ wrt [width, height].
+ rot (float): Rotation angle (degree).
+ output_size (np.ndarray[2, ] | list(2,)): Size of the
+ destination heatmaps.
+ shift (0-100%): Shift translation ratio wrt the width/height.
+ Default (0., 0.).
+ inv (bool): Option to inverse the affine transform direction.
+ (inv=False: src->dst or inv=True: dst->src)
+
+ Returns:
+ np.ndarray: A 2x3 transformation matrix
+ """
+ shift = np.array(shift)
+ src_w = scale[0]
+ dst_w = output_size[0]
+ dst_h = output_size[1]
+
+ # compute transformation matrix
+ rot_rad = np.deg2rad(rot)
+ src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
+ dst_dir = np.array([0., dst_w * -0.5])
+
+ # get four corners of the src rectangle in the original image
+ src = np.zeros((3, 2), dtype=np.float32)
+ src[0, :] = center + scale * shift
+ src[1, :] = center + src_dir + scale * shift
+ src[2, :] = _get_3rd_point(src[0, :], src[1, :])
+
+ # get four corners of the dst rectangle in the input image
+ dst = np.zeros((3, 2), dtype=np.float32)
+ dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
+ dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
+ dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
+
+ if inv:
+ warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
+ else:
+ warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
+
+ return warp_mat
+
+
+def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict,
+ img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """Get the bbox image as the model input by affine transform.
+
+ Args:
+ input_size (dict): The input size of the model.
+ bbox_scale (dict): The bbox scale of the img.
+ bbox_center (dict): The bbox center of the img.
+ img (np.ndarray): The original image.
+
+ Returns:
+ tuple: A tuple containing center and scale.
+ - np.ndarray[float32]: img after affine transform.
+ - np.ndarray[float32]: bbox scale after affine transform.
+ """
+ w, h = input_size
+ warp_size = (int(w), int(h))
+
+ # reshape bbox to fixed aspect ratio
+ bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
+
+ # get the affine matrix
+ center = bbox_center
+ scale = bbox_scale
+ rot = 0
+ warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
+
+ # do affine transform
+ img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
+
+ return img, bbox_scale
+
+
+def get_simcc_maximum(simcc_x: np.ndarray,
+ simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """Get maximum response location and value from simcc representations.
+
+ Note:
+ instance number: N
+ num_keypoints: K
+ heatmap height: H
+ heatmap width: W
+
+ Args:
+ simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
+ simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
+
+ Returns:
+ tuple:
+ - locs (np.ndarray): locations of maximum heatmap responses in shape
+ (K, 2) or (N, K, 2)
+ - vals (np.ndarray): values of maximum heatmap responses in shape
+ (K,) or (N, K)
+ """
+ N, K, Wx = simcc_x.shape
+ simcc_x = simcc_x.reshape(N * K, -1)
+ simcc_y = simcc_y.reshape(N * K, -1)
+
+ # get maximum value locations
+ x_locs = np.argmax(simcc_x, axis=1)
+ y_locs = np.argmax(simcc_y, axis=1)
+ locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
+ max_val_x = np.amax(simcc_x, axis=1)
+ max_val_y = np.amax(simcc_y, axis=1)
+
+ # get maximum value across x and y axis
+ mask = max_val_x > max_val_y
+ max_val_x[mask] = max_val_y[mask]
+ vals = max_val_x
+ locs[vals <= 0.] = -1
+
+ # reshape
+ locs = locs.reshape(N, K, 2)
+ vals = vals.reshape(N, K)
+
+ return locs, vals
+
+
+def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
+ simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
+ """Modulate simcc distribution with Gaussian.
+
+ Args:
+ simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
+ simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
+ simcc_split_ratio (int): The split ratio of simcc.
+
+ Returns:
+ tuple: A tuple containing center and scale.
+ - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
+ - np.ndarray[float32]: scores in shape (K,) or (n, K)
+ """
+ keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
+ keypoints /= simcc_split_ratio
+
+ return keypoints, scores
+
+def inference_pose(model, out_bbox, oriImg, model_input_size=(288, 384)):
+ resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
+ #outputs = inference(session, resized_img, dtype)
+ outputs = inference(model, resized_img)
+
+ keypoints, scores = postprocess(outputs, model_input_size, center, scale)
+
+ return keypoints, scores
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/dwpose/face.py b/src/custom_controlnet_aux/dwpose/face.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3c46d77664aa9fa91c63785a1485a396f05cacc
--- /dev/null
+++ b/src/custom_controlnet_aux/dwpose/face.py
@@ -0,0 +1,362 @@
+import logging
+import numpy as np
+from torchvision.transforms import ToTensor, ToPILImage
+import torch
+import torch.nn.functional as F
+import cv2
+
+from . import util
+from torch.nn import Conv2d, Module, ReLU, MaxPool2d, init
+
+
+class FaceNet(Module):
+ """Model the cascading heatmaps. """
+ def __init__(self):
+ super(FaceNet, self).__init__()
+ # cnn to make feature map
+ self.relu = ReLU()
+ self.max_pooling_2d = MaxPool2d(kernel_size=2, stride=2)
+ self.conv1_1 = Conv2d(in_channels=3, out_channels=64,
+ kernel_size=3, stride=1, padding=1)
+ self.conv1_2 = Conv2d(
+ in_channels=64, out_channels=64, kernel_size=3, stride=1,
+ padding=1)
+ self.conv2_1 = Conv2d(
+ in_channels=64, out_channels=128, kernel_size=3, stride=1,
+ padding=1)
+ self.conv2_2 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=3, stride=1,
+ padding=1)
+ self.conv3_1 = Conv2d(
+ in_channels=128, out_channels=256, kernel_size=3, stride=1,
+ padding=1)
+ self.conv3_2 = Conv2d(
+ in_channels=256, out_channels=256, kernel_size=3, stride=1,
+ padding=1)
+ self.conv3_3 = Conv2d(
+ in_channels=256, out_channels=256, kernel_size=3, stride=1,
+ padding=1)
+ self.conv3_4 = Conv2d(
+ in_channels=256, out_channels=256, kernel_size=3, stride=1,
+ padding=1)
+ self.conv4_1 = Conv2d(
+ in_channels=256, out_channels=512, kernel_size=3, stride=1,
+ padding=1)
+ self.conv4_2 = Conv2d(
+ in_channels=512, out_channels=512, kernel_size=3, stride=1,
+ padding=1)
+ self.conv4_3 = Conv2d(
+ in_channels=512, out_channels=512, kernel_size=3, stride=1,
+ padding=1)
+ self.conv4_4 = Conv2d(
+ in_channels=512, out_channels=512, kernel_size=3, stride=1,
+ padding=1)
+ self.conv5_1 = Conv2d(
+ in_channels=512, out_channels=512, kernel_size=3, stride=1,
+ padding=1)
+ self.conv5_2 = Conv2d(
+ in_channels=512, out_channels=512, kernel_size=3, stride=1,
+ padding=1)
+ self.conv5_3_CPM = Conv2d(
+ in_channels=512, out_channels=128, kernel_size=3, stride=1,
+ padding=1)
+
+ # stage1
+ self.conv6_1_CPM = Conv2d(
+ in_channels=128, out_channels=512, kernel_size=1, stride=1,
+ padding=0)
+ self.conv6_2_CPM = Conv2d(
+ in_channels=512, out_channels=71, kernel_size=1, stride=1,
+ padding=0)
+
+ # stage2
+ self.Mconv1_stage2 = Conv2d(
+ in_channels=199, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv2_stage2 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv3_stage2 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv4_stage2 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv5_stage2 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv6_stage2 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=1, stride=1,
+ padding=0)
+ self.Mconv7_stage2 = Conv2d(
+ in_channels=128, out_channels=71, kernel_size=1, stride=1,
+ padding=0)
+
+ # stage3
+ self.Mconv1_stage3 = Conv2d(
+ in_channels=199, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv2_stage3 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv3_stage3 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv4_stage3 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv5_stage3 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv6_stage3 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=1, stride=1,
+ padding=0)
+ self.Mconv7_stage3 = Conv2d(
+ in_channels=128, out_channels=71, kernel_size=1, stride=1,
+ padding=0)
+
+ # stage4
+ self.Mconv1_stage4 = Conv2d(
+ in_channels=199, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv2_stage4 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv3_stage4 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv4_stage4 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv5_stage4 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv6_stage4 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=1, stride=1,
+ padding=0)
+ self.Mconv7_stage4 = Conv2d(
+ in_channels=128, out_channels=71, kernel_size=1, stride=1,
+ padding=0)
+
+ # stage5
+ self.Mconv1_stage5 = Conv2d(
+ in_channels=199, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv2_stage5 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv3_stage5 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv4_stage5 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv5_stage5 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv6_stage5 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=1, stride=1,
+ padding=0)
+ self.Mconv7_stage5 = Conv2d(
+ in_channels=128, out_channels=71, kernel_size=1, stride=1,
+ padding=0)
+
+ # stage6
+ self.Mconv1_stage6 = Conv2d(
+ in_channels=199, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv2_stage6 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv3_stage6 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv4_stage6 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv5_stage6 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv6_stage6 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=1, stride=1,
+ padding=0)
+ self.Mconv7_stage6 = Conv2d(
+ in_channels=128, out_channels=71, kernel_size=1, stride=1,
+ padding=0)
+
+ for m in self.modules():
+ if isinstance(m, Conv2d):
+ init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ """Return a list of heatmaps."""
+ heatmaps = []
+
+ h = self.relu(self.conv1_1(x))
+ h = self.relu(self.conv1_2(h))
+ h = self.max_pooling_2d(h)
+ h = self.relu(self.conv2_1(h))
+ h = self.relu(self.conv2_2(h))
+ h = self.max_pooling_2d(h)
+ h = self.relu(self.conv3_1(h))
+ h = self.relu(self.conv3_2(h))
+ h = self.relu(self.conv3_3(h))
+ h = self.relu(self.conv3_4(h))
+ h = self.max_pooling_2d(h)
+ h = self.relu(self.conv4_1(h))
+ h = self.relu(self.conv4_2(h))
+ h = self.relu(self.conv4_3(h))
+ h = self.relu(self.conv4_4(h))
+ h = self.relu(self.conv5_1(h))
+ h = self.relu(self.conv5_2(h))
+ h = self.relu(self.conv5_3_CPM(h))
+ feature_map = h
+
+ # stage1
+ h = self.relu(self.conv6_1_CPM(h))
+ h = self.conv6_2_CPM(h)
+ heatmaps.append(h)
+
+ # stage2
+ h = torch.cat([h, feature_map], dim=1) # channel concat
+ h = self.relu(self.Mconv1_stage2(h))
+ h = self.relu(self.Mconv2_stage2(h))
+ h = self.relu(self.Mconv3_stage2(h))
+ h = self.relu(self.Mconv4_stage2(h))
+ h = self.relu(self.Mconv5_stage2(h))
+ h = self.relu(self.Mconv6_stage2(h))
+ h = self.Mconv7_stage2(h)
+ heatmaps.append(h)
+
+ # stage3
+ h = torch.cat([h, feature_map], dim=1) # channel concat
+ h = self.relu(self.Mconv1_stage3(h))
+ h = self.relu(self.Mconv2_stage3(h))
+ h = self.relu(self.Mconv3_stage3(h))
+ h = self.relu(self.Mconv4_stage3(h))
+ h = self.relu(self.Mconv5_stage3(h))
+ h = self.relu(self.Mconv6_stage3(h))
+ h = self.Mconv7_stage3(h)
+ heatmaps.append(h)
+
+ # stage4
+ h = torch.cat([h, feature_map], dim=1) # channel concat
+ h = self.relu(self.Mconv1_stage4(h))
+ h = self.relu(self.Mconv2_stage4(h))
+ h = self.relu(self.Mconv3_stage4(h))
+ h = self.relu(self.Mconv4_stage4(h))
+ h = self.relu(self.Mconv5_stage4(h))
+ h = self.relu(self.Mconv6_stage4(h))
+ h = self.Mconv7_stage4(h)
+ heatmaps.append(h)
+
+ # stage5
+ h = torch.cat([h, feature_map], dim=1) # channel concat
+ h = self.relu(self.Mconv1_stage5(h))
+ h = self.relu(self.Mconv2_stage5(h))
+ h = self.relu(self.Mconv3_stage5(h))
+ h = self.relu(self.Mconv4_stage5(h))
+ h = self.relu(self.Mconv5_stage5(h))
+ h = self.relu(self.Mconv6_stage5(h))
+ h = self.Mconv7_stage5(h)
+ heatmaps.append(h)
+
+ # stage6
+ h = torch.cat([h, feature_map], dim=1) # channel concat
+ h = self.relu(self.Mconv1_stage6(h))
+ h = self.relu(self.Mconv2_stage6(h))
+ h = self.relu(self.Mconv3_stage6(h))
+ h = self.relu(self.Mconv4_stage6(h))
+ h = self.relu(self.Mconv5_stage6(h))
+ h = self.relu(self.Mconv6_stage6(h))
+ h = self.Mconv7_stage6(h)
+ heatmaps.append(h)
+
+ return heatmaps
+
+
+LOG = logging.getLogger(__name__)
+TOTEN = ToTensor()
+TOPIL = ToPILImage()
+
+
+params = {
+ 'gaussian_sigma': 2.5,
+ 'inference_img_size': 736, # 368, 736, 1312
+ 'heatmap_peak_thresh': 0.1,
+ 'crop_scale': 1.5,
+ 'line_indices': [
+ [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6],
+ [6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12], [12, 13],
+ [13, 14], [14, 15], [15, 16],
+ [17, 18], [18, 19], [19, 20], [20, 21],
+ [22, 23], [23, 24], [24, 25], [25, 26],
+ [27, 28], [28, 29], [29, 30],
+ [31, 32], [32, 33], [33, 34], [34, 35],
+ [36, 37], [37, 38], [38, 39], [39, 40], [40, 41], [41, 36],
+ [42, 43], [43, 44], [44, 45], [45, 46], [46, 47], [47, 42],
+ [48, 49], [49, 50], [50, 51], [51, 52], [52, 53], [53, 54],
+ [54, 55], [55, 56], [56, 57], [57, 58], [58, 59], [59, 48],
+ [60, 61], [61, 62], [62, 63], [63, 64], [64, 65], [65, 66],
+ [66, 67], [67, 60]
+ ],
+}
+
+
+class Face(object):
+ """
+ The OpenPose face landmark detector model.
+
+ Args:
+ inference_size: set the size of the inference image size, suggested:
+ 368, 736, 1312, default 736
+ gaussian_sigma: blur the heatmaps, default 2.5
+ heatmap_peak_thresh: return landmark if over threshold, default 0.1
+
+ """
+ def __init__(self, face_model_path,
+ inference_size=None,
+ gaussian_sigma=None,
+ heatmap_peak_thresh=None):
+ self.inference_size = inference_size or params["inference_img_size"]
+ self.sigma = gaussian_sigma or params['gaussian_sigma']
+ self.threshold = heatmap_peak_thresh or params["heatmap_peak_thresh"]
+ self.model = FaceNet()
+ self.model.load_state_dict(torch.load(face_model_path))
+ # if torch.cuda.is_available():
+ # self.model = self.model.cuda()
+ # print('cuda')
+ self.model.eval()
+
+ def __call__(self, face_img):
+ H, W, C = face_img.shape
+
+ w_size = 384
+ x_data = torch.from_numpy(util.smart_resize(face_img, (w_size, w_size))).permute([2, 0, 1]) / 256.0 - 0.5
+
+ x_data = x_data.to(self.cn_device)
+
+ with torch.no_grad():
+ hs = self.model(x_data[None, ...])
+ heatmaps = F.interpolate(
+ hs[-1],
+ (H, W),
+ mode='bilinear', align_corners=True).cpu().numpy()[0]
+ return heatmaps
+
+ def compute_peaks_from_heatmaps(self, heatmaps):
+ all_peaks = []
+ for part in range(heatmaps.shape[0]):
+ map_ori = heatmaps[part].copy()
+ binary = np.ascontiguousarray(map_ori > 0.05, dtype=np.uint8)
+
+ if np.sum(binary) == 0:
+ continue
+
+ positions = np.where(binary > 0.5)
+ intensities = map_ori[positions]
+ mi = np.argmax(intensities)
+ y, x = positions[0][mi], positions[1][mi]
+ all_peaks.append([x, y])
+
+ return np.array(all_peaks)
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/dwpose/hand.py b/src/custom_controlnet_aux/dwpose/hand.py
new file mode 100644
index 0000000000000000000000000000000000000000..74767def506c72612954fe3b79056d17a83b1e16
--- /dev/null
+++ b/src/custom_controlnet_aux/dwpose/hand.py
@@ -0,0 +1,94 @@
+import cv2
+import json
+import numpy as np
+import math
+import time
+from scipy.ndimage.filters import gaussian_filter
+import matplotlib.pyplot as plt
+import matplotlib
+import torch
+from skimage.measure import label
+
+from .model import handpose_model
+from . import util
+
+class Hand(object):
+ def __init__(self, model_path):
+ self.model = handpose_model()
+ # if torch.cuda.is_available():
+ # self.model = self.model.cuda()
+ # print('cuda')
+ model_dict = util.transfer(self.model, torch.load(model_path))
+ self.model.load_state_dict(model_dict)
+ self.model.eval()
+
+ def __call__(self, oriImgRaw):
+ scale_search = [0.5, 1.0, 1.5, 2.0]
+ # scale_search = [0.5]
+ boxsize = 368
+ stride = 8
+ padValue = 128
+ thre = 0.05
+ multiplier = [x * boxsize for x in scale_search]
+
+ wsize = 128
+ heatmap_avg = np.zeros((wsize, wsize, 22))
+
+ Hr, Wr, Cr = oriImgRaw.shape
+
+ oriImg = cv2.GaussianBlur(oriImgRaw, (0, 0), 0.8)
+
+ for m in range(len(multiplier)):
+ scale = multiplier[m]
+ imageToTest = util.smart_resize(oriImg, (scale, scale))
+
+ imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
+ im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
+ im = np.ascontiguousarray(im)
+
+ data = torch.from_numpy(im).float()
+ if torch.cuda.is_available():
+ data = data.cuda()
+
+ with torch.no_grad():
+ data = data.to(self.cn_device)
+ output = self.model(data).cpu().numpy()
+
+ # extract outputs, resize, and remove padding
+ heatmap = np.transpose(np.squeeze(output), (1, 2, 0)) # output 1 is heatmaps
+ heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride)
+ heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+ heatmap = util.smart_resize(heatmap, (wsize, wsize))
+
+ heatmap_avg += heatmap / len(multiplier)
+
+ all_peaks = []
+ for part in range(21):
+ map_ori = heatmap_avg[:, :, part]
+ one_heatmap = gaussian_filter(map_ori, sigma=3)
+ binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)
+
+ if np.sum(binary) == 0:
+ all_peaks.append([0, 0])
+ continue
+ label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim)
+ max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1
+ label_img[label_img != max_index] = 0
+ map_ori[label_img == 0] = 0
+
+ y, x = util.npmax(map_ori)
+ y = int(float(y) * float(Hr) / float(wsize))
+ x = int(float(x) * float(Wr) / float(wsize))
+ all_peaks.append([x, y])
+ return np.array(all_peaks)
+
+if __name__ == "__main__":
+ hand_estimation = Hand('../model/hand_pose_model.pth')
+
+ # test_image = '../images/hand.jpg'
+ test_image = '../images/hand.jpg'
+ oriImg = cv2.imread(test_image) # B,G,R order
+ peaks = hand_estimation(oriImg)
+ canvas = util.draw_handpose(oriImg, peaks, True)
+ cv2.imshow('', canvas)
+ cv2.waitKey(0)
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/dwpose/model.py b/src/custom_controlnet_aux/dwpose/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..72dc79ad857933a7c108d21494d6395572b816e6
--- /dev/null
+++ b/src/custom_controlnet_aux/dwpose/model.py
@@ -0,0 +1,218 @@
+import torch
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+
+def make_layers(block, no_relu_layers):
+ layers = []
+ for layer_name, v in block.items():
+ if 'pool' in layer_name:
+ layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1],
+ padding=v[2])
+ layers.append((layer_name, layer))
+ else:
+ conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1],
+ kernel_size=v[2], stride=v[3],
+ padding=v[4])
+ layers.append((layer_name, conv2d))
+ if layer_name not in no_relu_layers:
+ layers.append(('relu_'+layer_name, nn.ReLU(inplace=True)))
+
+ return nn.Sequential(OrderedDict(layers))
+
+class bodypose_model(nn.Module):
+ def __init__(self):
+ super(bodypose_model, self).__init__()
+
+ # these layers have no relu layer
+ no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\
+ 'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\
+ 'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\
+ 'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L1']
+ blocks = {}
+ block0 = OrderedDict([
+ ('conv1_1', [3, 64, 3, 1, 1]),
+ ('conv1_2', [64, 64, 3, 1, 1]),
+ ('pool1_stage1', [2, 2, 0]),
+ ('conv2_1', [64, 128, 3, 1, 1]),
+ ('conv2_2', [128, 128, 3, 1, 1]),
+ ('pool2_stage1', [2, 2, 0]),
+ ('conv3_1', [128, 256, 3, 1, 1]),
+ ('conv3_2', [256, 256, 3, 1, 1]),
+ ('conv3_3', [256, 256, 3, 1, 1]),
+ ('conv3_4', [256, 256, 3, 1, 1]),
+ ('pool3_stage1', [2, 2, 0]),
+ ('conv4_1', [256, 512, 3, 1, 1]),
+ ('conv4_2', [512, 512, 3, 1, 1]),
+ ('conv4_3_CPM', [512, 256, 3, 1, 1]),
+ ('conv4_4_CPM', [256, 128, 3, 1, 1])
+ ])
+
+
+ # Stage 1
+ block1_1 = OrderedDict([
+ ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]),
+ ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
+ ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]),
+ ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
+ ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])
+ ])
+
+ block1_2 = OrderedDict([
+ ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]),
+ ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
+ ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]),
+ ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
+ ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])
+ ])
+ blocks['block1_1'] = block1_1
+ blocks['block1_2'] = block1_2
+
+ self.model0 = make_layers(block0, no_relu_layers)
+
+ # Stages 2 - 6
+ for i in range(2, 7):
+ blocks['block%d_1' % i] = OrderedDict([
+ ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]),
+ ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+ ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+ ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+ ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+ ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]),
+ ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])
+ ])
+
+ blocks['block%d_2' % i] = OrderedDict([
+ ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]),
+ ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+ ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+ ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+ ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+ ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]),
+ ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])
+ ])
+
+ for k in blocks.keys():
+ blocks[k] = make_layers(blocks[k], no_relu_layers)
+
+ self.model1_1 = blocks['block1_1']
+ self.model2_1 = blocks['block2_1']
+ self.model3_1 = blocks['block3_1']
+ self.model4_1 = blocks['block4_1']
+ self.model5_1 = blocks['block5_1']
+ self.model6_1 = blocks['block6_1']
+
+ self.model1_2 = blocks['block1_2']
+ self.model2_2 = blocks['block2_2']
+ self.model3_2 = blocks['block3_2']
+ self.model4_2 = blocks['block4_2']
+ self.model5_2 = blocks['block5_2']
+ self.model6_2 = blocks['block6_2']
+
+
+ def forward(self, x):
+
+ out1 = self.model0(x)
+
+ out1_1 = self.model1_1(out1)
+ out1_2 = self.model1_2(out1)
+ out2 = torch.cat([out1_1, out1_2, out1], 1)
+
+ out2_1 = self.model2_1(out2)
+ out2_2 = self.model2_2(out2)
+ out3 = torch.cat([out2_1, out2_2, out1], 1)
+
+ out3_1 = self.model3_1(out3)
+ out3_2 = self.model3_2(out3)
+ out4 = torch.cat([out3_1, out3_2, out1], 1)
+
+ out4_1 = self.model4_1(out4)
+ out4_2 = self.model4_2(out4)
+ out5 = torch.cat([out4_1, out4_2, out1], 1)
+
+ out5_1 = self.model5_1(out5)
+ out5_2 = self.model5_2(out5)
+ out6 = torch.cat([out5_1, out5_2, out1], 1)
+
+ out6_1 = self.model6_1(out6)
+ out6_2 = self.model6_2(out6)
+
+ return out6_1, out6_2
+
+class handpose_model(nn.Module):
+ def __init__(self):
+ super(handpose_model, self).__init__()
+
+ # these layers have no relu layer
+ no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\
+ 'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6']
+ # stage 1
+ block1_0 = OrderedDict([
+ ('conv1_1', [3, 64, 3, 1, 1]),
+ ('conv1_2', [64, 64, 3, 1, 1]),
+ ('pool1_stage1', [2, 2, 0]),
+ ('conv2_1', [64, 128, 3, 1, 1]),
+ ('conv2_2', [128, 128, 3, 1, 1]),
+ ('pool2_stage1', [2, 2, 0]),
+ ('conv3_1', [128, 256, 3, 1, 1]),
+ ('conv3_2', [256, 256, 3, 1, 1]),
+ ('conv3_3', [256, 256, 3, 1, 1]),
+ ('conv3_4', [256, 256, 3, 1, 1]),
+ ('pool3_stage1', [2, 2, 0]),
+ ('conv4_1', [256, 512, 3, 1, 1]),
+ ('conv4_2', [512, 512, 3, 1, 1]),
+ ('conv4_3', [512, 512, 3, 1, 1]),
+ ('conv4_4', [512, 512, 3, 1, 1]),
+ ('conv5_1', [512, 512, 3, 1, 1]),
+ ('conv5_2', [512, 512, 3, 1, 1]),
+ ('conv5_3_CPM', [512, 128, 3, 1, 1])
+ ])
+
+ block1_1 = OrderedDict([
+ ('conv6_1_CPM', [128, 512, 1, 1, 0]),
+ ('conv6_2_CPM', [512, 22, 1, 1, 0])
+ ])
+
+ blocks = {}
+ blocks['block1_0'] = block1_0
+ blocks['block1_1'] = block1_1
+
+ # stage 2-6
+ for i in range(2, 7):
+ blocks['block%d' % i] = OrderedDict([
+ ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]),
+ ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]),
+ ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]),
+ ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]),
+ ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]),
+ ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]),
+ ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0])
+ ])
+
+ for k in blocks.keys():
+ blocks[k] = make_layers(blocks[k], no_relu_layers)
+
+ self.model1_0 = blocks['block1_0']
+ self.model1_1 = blocks['block1_1']
+ self.model2 = blocks['block2']
+ self.model3 = blocks['block3']
+ self.model4 = blocks['block4']
+ self.model5 = blocks['block5']
+ self.model6 = blocks['block6']
+
+ def forward(self, x):
+ out1_0 = self.model1_0(x)
+ out1_1 = self.model1_1(out1_0)
+ concat_stage2 = torch.cat([out1_1, out1_0], 1)
+ out_stage2 = self.model2(concat_stage2)
+ concat_stage3 = torch.cat([out_stage2, out1_0], 1)
+ out_stage3 = self.model3(concat_stage3)
+ concat_stage4 = torch.cat([out_stage3, out1_0], 1)
+ out_stage4 = self.model4(concat_stage4)
+ concat_stage5 = torch.cat([out_stage4, out1_0], 1)
+ out_stage5 = self.model5(concat_stage5)
+ concat_stage6 = torch.cat([out_stage5, out1_0], 1)
+ out_stage6 = self.model6(concat_stage6)
+ return out_stage6
+
diff --git a/src/custom_controlnet_aux/dwpose/types.py b/src/custom_controlnet_aux/dwpose/types.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a5fa889e401bcd85e50977a674ce601bb02a46a
--- /dev/null
+++ b/src/custom_controlnet_aux/dwpose/types.py
@@ -0,0 +1,30 @@
+from typing import NamedTuple, List, Optional
+
+class Keypoint(NamedTuple):
+ x: float
+ y: float
+ score: float = 1.0
+ id: int = -1
+
+
+class BodyResult(NamedTuple):
+ # Note: Using `Optional` instead of `|` operator as the ladder is a Python
+ # 3.10 feature.
+ # Annotator code should be Python 3.8 Compatible, as controlnet repo uses
+ # Python 3.8 environment.
+ # https://github.com/lllyasviel/ControlNet/blob/d3284fcd0972c510635a4f5abe2eeb71dc0de524/environment.yaml#L6
+ keypoints: List[Optional[Keypoint]]
+ total_score: float = 0.0
+ total_parts: int = 0
+
+
+HandResult = List[Keypoint]
+FaceResult = List[Keypoint]
+AnimalPoseResult = List[Keypoint]
+
+
+class PoseResult(NamedTuple):
+ body: BodyResult
+ left_hand: Optional[HandResult]
+ right_hand: Optional[HandResult]
+ face: Optional[FaceResult]
diff --git a/src/custom_controlnet_aux/dwpose/util.py b/src/custom_controlnet_aux/dwpose/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..42874b33bd69bf12f0ff0c0b8ff09cd512349d5f
--- /dev/null
+++ b/src/custom_controlnet_aux/dwpose/util.py
@@ -0,0 +1,466 @@
+import math
+import numpy as np
+import matplotlib
+import cv2
+import os
+from typing import List, Tuple, Union, Optional
+
+from .body import BodyResult, Keypoint
+
+eps = 0.01
+
+
+def smart_resize(x, s):
+ Ht, Wt = s
+ if x.ndim == 2:
+ Ho, Wo = x.shape
+ Co = 1
+ else:
+ Ho, Wo, Co = x.shape
+ if Co == 3 or Co == 1:
+ k = float(Ht + Wt) / float(Ho + Wo)
+ return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
+ else:
+ return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
+
+
+def smart_resize_k(x, fx, fy):
+ if x.ndim == 2:
+ Ho, Wo = x.shape
+ Co = 1
+ else:
+ Ho, Wo, Co = x.shape
+ Ht, Wt = Ho * fy, Wo * fx
+ if Co == 3 or Co == 1:
+ k = float(Ht + Wt) / float(Ho + Wo)
+ return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
+ else:
+ return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
+
+
+def padRightDownCorner(img, stride, padValue):
+ h = img.shape[0]
+ w = img.shape[1]
+
+ pad = 4 * [None]
+ pad[0] = 0 # up
+ pad[1] = 0 # left
+ pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
+ pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
+
+ img_padded = img
+ pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1))
+ img_padded = np.concatenate((pad_up, img_padded), axis=0)
+ pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1))
+ img_padded = np.concatenate((pad_left, img_padded), axis=1)
+ pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1))
+ img_padded = np.concatenate((img_padded, pad_down), axis=0)
+ pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1))
+ img_padded = np.concatenate((img_padded, pad_right), axis=1)
+
+ return img_padded, pad
+
+
+def transfer(model, model_weights):
+ transfered_model_weights = {}
+ for weights_name in model.state_dict().keys():
+ transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
+ return transfered_model_weights
+
+
+def is_normalized(keypoints: List[Optional[Keypoint]]) -> bool:
+ point_normalized = [
+ 0 <= abs(k.x) <= 1 and 0 <= abs(k.y) <= 1
+ for k in keypoints
+ if k is not None
+ ]
+ if not point_normalized:
+ return False
+ return all(point_normalized)
+
+
+def draw_bodypose(canvas: np.ndarray, keypoints: List[Keypoint], xinsr_stick_scaling: bool = False) -> np.ndarray:
+ """
+ Draw keypoints and limbs representing body pose on a given canvas.
+
+ Args:
+ canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the body pose.
+ keypoints (List[Keypoint]): A list of Keypoint objects representing the body keypoints to be drawn.
+ xinsr_stick_scaling (bool): Whether or not scaling stick width for xinsr ControlNet
+
+ Returns:
+ np.ndarray: A 3D numpy array representing the modified canvas with the drawn body pose.
+
+ Note:
+ The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
+ """
+ if not is_normalized(keypoints):
+ H, W = 1.0, 1.0
+ else:
+ H, W, _ = canvas.shape
+
+ CH, CW, _ = canvas.shape
+ stickwidth = 4
+
+ # Ref: https://huggingface.co/xinsir/controlnet-openpose-sdxl-1.0
+ max_side = max(CW, CH)
+ if xinsr_stick_scaling:
+ stick_scale = 1 if max_side < 500 else min(2 + (max_side // 1000), 7)
+ else:
+ stick_scale = 1
+
+ limbSeq = [
+ [2, 3], [2, 6], [3, 4], [4, 5],
+ [6, 7], [7, 8], [2, 9], [9, 10],
+ [10, 11], [2, 12], [12, 13], [13, 14],
+ [2, 1], [1, 15], [15, 17], [1, 16],
+ [16, 18],
+ ]
+
+ colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
+
+ for (k1_index, k2_index), color in zip(limbSeq, colors):
+ keypoint1 = keypoints[k1_index - 1]
+ keypoint2 = keypoints[k2_index - 1]
+
+ if keypoint1 is None or keypoint2 is None:
+ continue
+
+ Y = np.array([keypoint1.x, keypoint2.x]) * float(W)
+ X = np.array([keypoint1.y, keypoint2.y]) * float(H)
+ mX = np.mean(X)
+ mY = np.mean(Y)
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth*stick_scale), int(angle), 0, 360, 1)
+ cv2.fillConvexPoly(canvas, polygon, [int(float(c) * 0.6) for c in color])
+
+ for keypoint, color in zip(keypoints, colors):
+ if keypoint is None:
+ continue
+
+ x, y = keypoint.x, keypoint.y
+ x = int(x * W)
+ y = int(y * H)
+ cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1)
+
+ return canvas
+
+
+def draw_handpose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray:
+ """
+ Draw keypoints and connections representing hand pose on a given canvas.
+
+ Args:
+ canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.
+ keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn
+ or None if no keypoints are present.
+
+ Returns:
+ np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.
+
+ Note:
+ The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
+ """
+ if not keypoints:
+ return canvas
+
+ if not is_normalized(keypoints):
+ H, W = 1.0, 1.0
+ else:
+ H, W, _ = canvas.shape
+
+ edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
+ [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
+
+ for ie, (e1, e2) in enumerate(edges):
+ k1 = keypoints[e1]
+ k2 = keypoints[e2]
+ if k1 is None or k2 is None:
+ continue
+
+ x1 = int(k1.x * W)
+ y1 = int(k1.y * H)
+ x2 = int(k2.x * W)
+ y2 = int(k2.y * H)
+ if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
+ cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2)
+
+ for keypoint in keypoints:
+ if keypoint is None:
+ continue
+
+ x, y = keypoint.x, keypoint.y
+ x = int(x * W)
+ y = int(y * H)
+ if x > eps and y > eps:
+ cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
+ return canvas
+
+
+def draw_facepose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray:
+ """
+ Draw keypoints representing face pose on a given canvas.
+
+ Args:
+ canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the face pose.
+ keypoints (List[Keypoint]| None): A list of Keypoint objects representing the face keypoints to be drawn
+ or None if no keypoints are present.
+
+ Returns:
+ np.ndarray: A 3D numpy array representing the modified canvas with the drawn face pose.
+
+ Note:
+ The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
+ """
+ if not keypoints:
+ return canvas
+
+ if not is_normalized(keypoints):
+ H, W = 1.0, 1.0
+ else:
+ H, W, _ = canvas.shape
+
+ for keypoint in keypoints:
+ if keypoint is None:
+ continue
+
+ x, y = keypoint.x, keypoint.y
+ x = int(x * W)
+ y = int(y * H)
+ if x > eps and y > eps:
+ cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
+ return canvas
+
+
+# detect hand according to body pose keypoints
+# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
+def handDetect(body: BodyResult, oriImg) -> List[Tuple[int, int, int, bool]]:
+ """
+ Detect hands in the input body pose keypoints and calculate the bounding box for each hand.
+
+ Args:
+ body (BodyResult): A BodyResult object containing the detected body pose keypoints.
+ oriImg (numpy.ndarray): A 3D numpy array representing the original input image.
+
+ Returns:
+ List[Tuple[int, int, int, bool]]: A list of tuples, each containing the coordinates (x, y) of the top-left
+ corner of the bounding box, the width (height) of the bounding box, and
+ a boolean flag indicating whether the hand is a left hand (True) or a
+ right hand (False).
+
+ Notes:
+ - The width and height of the bounding boxes are equal since the network requires squared input.
+ - The minimum bounding box size is 20 pixels.
+ """
+ ratioWristElbow = 0.33
+ detect_result = []
+ image_height, image_width = oriImg.shape[0:2]
+
+ keypoints = body.keypoints
+ # right hand: wrist 4, elbow 3, shoulder 2
+ # left hand: wrist 7, elbow 6, shoulder 5
+ left_shoulder = keypoints[5]
+ left_elbow = keypoints[6]
+ left_wrist = keypoints[7]
+ right_shoulder = keypoints[2]
+ right_elbow = keypoints[3]
+ right_wrist = keypoints[4]
+
+ # if any of three not detected
+ has_left = all(keypoint is not None for keypoint in (left_shoulder, left_elbow, left_wrist))
+ has_right = all(keypoint is not None for keypoint in (right_shoulder, right_elbow, right_wrist))
+ if not (has_left or has_right):
+ return []
+
+ hands = []
+ #left hand
+ if has_left:
+ hands.append([
+ left_shoulder.x, left_shoulder.y,
+ left_elbow.x, left_elbow.y,
+ left_wrist.x, left_wrist.y,
+ True
+ ])
+ # right hand
+ if has_right:
+ hands.append([
+ right_shoulder.x, right_shoulder.y,
+ right_elbow.x, right_elbow.y,
+ right_wrist.x, right_wrist.y,
+ False
+ ])
+
+ for x1, y1, x2, y2, x3, y3, is_left in hands:
+ # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
+ # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
+ # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
+ # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
+ # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
+ # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
+ x = x3 + ratioWristElbow * (x3 - x2)
+ y = y3 + ratioWristElbow * (y3 - y2)
+ distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
+ distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
+ width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
+ # x-y refers to the center --> offset to topLeft point
+ # handRectangle.x -= handRectangle.width / 2.f;
+ # handRectangle.y -= handRectangle.height / 2.f;
+ x -= width / 2
+ y -= width / 2 # width = height
+ # overflow the image
+ if x < 0: x = 0
+ if y < 0: y = 0
+ width1 = width
+ width2 = width
+ if x + width > image_width: width1 = image_width - x
+ if y + width > image_height: width2 = image_height - y
+ width = min(width1, width2)
+ # the max hand box value is 20 pixels
+ if width >= 20:
+ detect_result.append((int(x), int(y), int(width), is_left))
+
+ '''
+ return value: [[x, y, w, True if left hand else False]].
+ width=height since the network require squared input.
+ x, y is the coordinate of top left
+ '''
+ return detect_result
+
+
+# Written by Lvmin
+def faceDetect(body: BodyResult, oriImg) -> Union[Tuple[int, int, int], None]:
+ """
+ Detect the face in the input body pose keypoints and calculate the bounding box for the face.
+
+ Args:
+ body (BodyResult): A BodyResult object containing the detected body pose keypoints.
+ oriImg (numpy.ndarray): A 3D numpy array representing the original input image.
+
+ Returns:
+ Tuple[int, int, int] | None: A tuple containing the coordinates (x, y) of the top-left corner of the
+ bounding box and the width (height) of the bounding box, or None if the
+ face is not detected or the bounding box width is less than 20 pixels.
+
+ Notes:
+ - The width and height of the bounding box are equal.
+ - The minimum bounding box size is 20 pixels.
+ """
+ # left right eye ear 14 15 16 17
+ image_height, image_width = oriImg.shape[0:2]
+
+ keypoints = body.keypoints
+ head = keypoints[0]
+ left_eye = keypoints[14]
+ right_eye = keypoints[15]
+ left_ear = keypoints[16]
+ right_ear = keypoints[17]
+
+ if head is None or all(keypoint is None for keypoint in (left_eye, right_eye, left_ear, right_ear)):
+ return None
+
+ width = 0.0
+ x0, y0 = head.x, head.y
+
+ if left_eye is not None:
+ x1, y1 = left_eye.x, left_eye.y
+ d = max(abs(x0 - x1), abs(y0 - y1))
+ width = max(width, d * 3.0)
+
+ if right_eye is not None:
+ x1, y1 = right_eye.x, right_eye.y
+ d = max(abs(x0 - x1), abs(y0 - y1))
+ width = max(width, d * 3.0)
+
+ if left_ear is not None:
+ x1, y1 = left_ear.x, left_ear.y
+ d = max(abs(x0 - x1), abs(y0 - y1))
+ width = max(width, d * 1.5)
+
+ if right_ear is not None:
+ x1, y1 = right_ear.x, right_ear.y
+ d = max(abs(x0 - x1), abs(y0 - y1))
+ width = max(width, d * 1.5)
+
+ x, y = x0, y0
+
+ x -= width
+ y -= width
+
+ if x < 0:
+ x = 0
+
+ if y < 0:
+ y = 0
+
+ width1 = width * 2
+ width2 = width * 2
+
+ if x + width > image_width:
+ width1 = image_width - x
+
+ if y + width > image_height:
+ width2 = image_height - y
+
+ width = min(width1, width2)
+
+ if width >= 20:
+ return int(x), int(y), int(width)
+ else:
+ return None
+
+
+# get max index of 2d array
+def npmax(array):
+ arrayindex = array.argmax(1)
+ arrayvalue = array.max(1)
+ i = arrayvalue.argmax()
+ j = arrayindex[i]
+ return i, j
+
+def guess_onnx_input_shape_dtype(filename):
+ dtype = np.float32
+ if "fp16" in filename:
+ dtype = np.float16
+ elif "int8" in filename:
+ dtype = np.uint8
+ input_size = (640, 640) if "yolo" in filename else (192, 256)
+ if "384" in filename:
+ input_size = (288, 384)
+ elif "256" in filename:
+ input_size = (256, 256)
+ return input_size, dtype
+
+if os.getenv('AUX_ORT_PROVIDERS'):
+ ONNX_PROVIDERS = os.getenv('AUX_ORT_PROVIDERS').split(',')
+else:
+ ONNX_PROVIDERS = ["CUDAExecutionProvider", "DirectMLExecutionProvider", "OpenVINOExecutionProvider", "ROCMExecutionProvider", "CPUExecutionProvider"]
+def get_ort_providers() -> List[str]:
+ providers = []
+ try:
+ import onnxruntime as ort
+ for provider in ONNX_PROVIDERS:
+ if provider in ort.get_available_providers():
+ providers.append(provider)
+ return providers
+ except:
+ return []
+
+def is_model_torchscript(model) -> bool:
+ return bool(type(model).__name__ == "RecursiveScriptModule")
+
+def get_model_type(Nodesname, filename) -> str:
+ ort_providers = list(filter(lambda x : x != "CPUExecutionProvider", get_ort_providers()))
+ if filename is None:
+ return None
+ elif ("onnx" in filename) and ort_providers:
+ print(f"{Nodesname}: Caching ONNXRuntime session {filename}...")
+ return "ort"
+ elif ("onnx" in filename):
+ print(f"{Nodesname}: Caching OpenCV DNN module {filename} on cv2.DNN...")
+ return "cv2"
+ else:
+ print(f"{Nodesname}: Caching TorchScript module {filename} on ...")
+ return "torchscript"
diff --git a/src/custom_controlnet_aux/dwpose/wholebody.py b/src/custom_controlnet_aux/dwpose/wholebody.py
new file mode 100644
index 0000000000000000000000000000000000000000..d453739ec0a0f8af9229bbfbc2ee440560e7c888
--- /dev/null
+++ b/src/custom_controlnet_aux/dwpose/wholebody.py
@@ -0,0 +1,181 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import cv2
+import numpy as np
+
+from .dw_onnx.cv_ox_det import inference_detector as inference_onnx_yolox
+from .dw_onnx.cv_ox_yolo_nas import inference_detector as inference_onnx_yolo_nas
+from .dw_onnx.cv_ox_pose import inference_pose as inference_onnx_pose
+
+from .dw_torchscript.jit_det import inference_detector as inference_jit_yolox
+from .dw_torchscript.jit_pose import inference_pose as inference_jit_pose
+
+from typing import List, Optional
+from .types import PoseResult, BodyResult, Keypoint
+from timeit import default_timer
+import os
+from .util import guess_onnx_input_shape_dtype, get_model_type, get_ort_providers, is_model_torchscript
+import torch
+
+class Wholebody:
+ def __init__(self, det_model_path: Optional[str] = None, pose_model_path: Optional[str] = None, torchscript_device="cuda"):
+ self.det_filename = det_model_path and os.path.basename(det_model_path)
+ self.pose_filename = pose_model_path and os.path.basename(pose_model_path)
+ self.det, self.pose = None, None
+ # return type: None ort cv2 torchscript
+ self.det_model_type = get_model_type("DWPose",self.det_filename)
+ self.pose_model_type = get_model_type("DWPose",self.pose_filename)
+ # Always loads to CPU to avoid building OpenCV.
+ cv2_device = 'cpu'
+ cv2_backend = cv2.dnn.DNN_BACKEND_OPENCV if cv2_device == 'cpu' else cv2.dnn.DNN_BACKEND_CUDA
+ # You need to manually build OpenCV through cmake to work with your GPU.
+ cv2_providers = cv2.dnn.DNN_TARGET_CPU if cv2_device == 'cpu' else cv2.dnn.DNN_TARGET_CUDA
+ ort_providers = get_ort_providers()
+
+ if self.det_model_type is None:
+ pass
+ elif self.det_model_type == "ort":
+ try:
+ import onnxruntime as ort
+ self.det = ort.InferenceSession(det_model_path, providers=ort_providers)
+ except:
+ print(f"Failed to load onnxruntime with {self.det.get_providers()}.\nPlease change EP_list in the config.yaml and restart ComfyUI")
+ self.det = ort.InferenceSession(det_model_path, providers=["CPUExecutionProvider"])
+ elif self.det_model_type == "cv2":
+ try:
+ self.det = cv2.dnn.readNetFromONNX(det_model_path)
+ self.det.setPreferableBackend(cv2_backend)
+ self.det.setPreferableTarget(cv2_providers)
+ except:
+ print("TopK operators may not work on your OpenCV, try use onnxruntime with CPUExecutionProvider")
+ try:
+ import onnxruntime as ort
+ self.det = ort.InferenceSession(det_model_path, providers=["CPUExecutionProvider"])
+ except:
+ print(f"Failed to load {det_model_path}, you can use other models instead")
+ else:
+ self.det = torch.jit.load(det_model_path)
+ self.det.to(torchscript_device)
+
+ if self.pose_model_type is None:
+ pass
+ elif self.pose_model_type == "ort":
+ try:
+ import onnxruntime as ort
+ self.pose = ort.InferenceSession(pose_model_path, providers=ort_providers)
+ except:
+ print(f"Failed to load onnxruntime with {self.pose.get_providers()}.\nPlease change EP_list in the config.yaml and restart ComfyUI")
+ self.pose = ort.InferenceSession(pose_model_path, providers=["CPUExecutionProvider"])
+ elif self.pose_model_type == "cv2":
+ self.pose = cv2.dnn.readNetFromONNX(pose_model_path)
+ self.pose.setPreferableBackend(cv2_backend)
+ self.pose.setPreferableTarget(cv2_providers)
+ else:
+ self.pose = torch.jit.load(pose_model_path)
+ self.pose.to(torchscript_device)
+
+ if self.pose_filename is not None:
+ self.pose_input_size, _ = guess_onnx_input_shape_dtype(self.pose_filename)
+
+ def __call__(self, oriImg) -> Optional[np.ndarray]:
+ #Sacrifice accurate time measurement for compatibility
+
+ det_result = None
+
+ if self.det is None:
+ print("DWPose: No detector specified, using full image for pose estimation.") # pragma: no cover
+ det_result = []
+ else:
+ det_start = default_timer()
+ if is_model_torchscript(self.det):
+ det_result = inference_jit_yolox(self.det, oriImg, detect_classes=[0])
+ else:
+ if "yolox" in self.det_filename:
+ det_result = inference_onnx_yolox(self.det, oriImg, detect_classes=[0], dtype=np.float32)
+ else:
+ #FP16 and INT8 YOLO NAS accept uint8 input
+ det_result = inference_onnx_yolo_nas(self.det, oriImg, detect_classes=[0], dtype=np.uint8)
+ print(f"DWPose: Bbox {((default_timer() - det_start) * 1000):.2f}ms")
+ if (det_result is None) or (det_result.shape[0] == 0):
+ return None
+
+ pose_start = default_timer()
+ if is_model_torchscript(self.pose):
+ keypoints, scores = inference_jit_pose(self.pose, det_result, oriImg, self.pose_input_size)
+ else:
+ _, pose_onnx_dtype = guess_onnx_input_shape_dtype(self.pose_filename)
+ keypoints, scores = inference_onnx_pose(self.pose, det_result, oriImg, self.pose_input_size, dtype=pose_onnx_dtype)
+
+ num_subjects_log = 'full image'
+ if hasattr(det_result, 'shape') and det_result.shape[0] > 0:
+ num_subjects_log = f"{det_result.shape[0]} people"
+ elif isinstance(det_result, list) and len(det_result) > 0 and isinstance(det_result[0], (list, np.ndarray)):
+ num_subjects_log = f"{len(det_result)} people"
+
+ print(f"DWPose: Pose {((default_timer() - pose_start) * 1000):.2f}ms on {num_subjects_log}\n")
+
+ keypoints_info = np.concatenate(
+ (keypoints, scores[..., None]), axis=-1)
+ # compute neck joint
+ neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
+ # neck score when visualizing pred
+ neck[:, 2:4] = np.logical_and(
+ keypoints_info[:, 5, 2:4] > 0.3,
+ keypoints_info[:, 6, 2:4] > 0.3).astype(int)
+ new_keypoints_info = np.insert(
+ keypoints_info, 17, neck, axis=1)
+ mmpose_idx = [
+ 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
+ ]
+ openpose_idx = [
+ 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
+ ]
+ new_keypoints_info[:, openpose_idx] = \
+ new_keypoints_info[:, mmpose_idx]
+ keypoints_info = new_keypoints_info
+
+ return keypoints_info
+
+ @staticmethod
+ def format_result(keypoints_info: Optional[np.ndarray]) -> List[PoseResult]:
+ def format_keypoint_part(
+ part: np.ndarray,
+ ) -> Optional[List[Optional[Keypoint]]]:
+ keypoints = [
+ Keypoint(x, y, score, i) if score >= 0.3 else None
+ for i, (x, y, score) in enumerate(part)
+ ]
+ return (
+ None if all(keypoint is None for keypoint in keypoints) else keypoints
+ )
+
+ def total_score(keypoints: Optional[List[Optional[Keypoint]]]) -> float:
+ return (
+ sum(keypoint.score for keypoint in keypoints if keypoint is not None)
+ if keypoints is not None
+ else 0.0
+ )
+
+ pose_results = []
+ if keypoints_info is None:
+ return pose_results
+
+ for instance in keypoints_info:
+ body_keypoints = format_keypoint_part(instance[:18]) or ([None] * 18)
+ left_hand = format_keypoint_part(instance[92:113])
+ right_hand = format_keypoint_part(instance[113:134])
+ face = format_keypoint_part(instance[24:92])
+
+ # Openpose face consists of 70 points in total, while DWPose only
+ # provides 68 points. Padding the last 2 points.
+ if face is not None:
+ # left eye
+ face.append(body_keypoints[14])
+ # right eye
+ face.append(body_keypoints[15])
+
+ body = BodyResult(
+ body_keypoints, total_score(body_keypoints), len(body_keypoints)
+ )
+ pose_results.append(PoseResult(body, left_hand, right_hand, face))
+
+ return pose_results
diff --git a/src/custom_controlnet_aux/hed/__init__.py b/src/custom_controlnet_aux/hed/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..022c4f999303b8496dbb64bce9457ecebf728efb
--- /dev/null
+++ b/src/custom_controlnet_aux/hed/__init__.py
@@ -0,0 +1,110 @@
+# This is an improved version and model of HED edge detection with Apache License, Version 2.0.
+# Please use this implementation in your products
+# This implementation may produce slightly different results from Saining Xie's official implementations,
+# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
+# Different from official models and other implementations, this is an RGB-input model (rather than BGR)
+# and in this way it works better for gradio's RGB protocol
+
+import os
+import warnings
+
+import cv2
+import numpy as np
+import torch
+from einops import rearrange
+from PIL import Image
+
+from custom_controlnet_aux.util import HWC3, nms, resize_image_with_pad, safe_step, common_input_validate, custom_hf_download, HF_MODEL_NAME
+
+
+class DoubleConvBlock(torch.nn.Module):
+ def __init__(self, input_channel, output_channel, layer_number):
+ super().__init__()
+ self.convs = torch.nn.Sequential()
+ self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
+ for i in range(1, layer_number):
+ self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
+ self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)
+
+ def __call__(self, x, down_sampling=False):
+ h = x
+ if down_sampling:
+ h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
+ for conv in self.convs:
+ h = conv(h)
+ h = torch.nn.functional.relu(h)
+ return h, self.projection(h)
+
+
+class ControlNetHED_Apache2(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
+ self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
+ self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
+ self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
+ self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
+ self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)
+
+ def __call__(self, x):
+ h = x - self.norm
+ h, projection1 = self.block1(h)
+ h, projection2 = self.block2(h, down_sampling=True)
+ h, projection3 = self.block3(h, down_sampling=True)
+ h, projection4 = self.block4(h, down_sampling=True)
+ h, projection5 = self.block5(h, down_sampling=True)
+ return projection1, projection2, projection3, projection4, projection5
+
+class HEDdetector:
+ def __init__(self, netNetwork):
+ self.netNetwork = netNetwork
+ self.device = "cpu"
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, filename="ControlNetHED.pth"):
+ model_path = custom_hf_download(pretrained_model_or_path, filename)
+
+ netNetwork = ControlNetHED_Apache2()
+ netNetwork.load_state_dict(torch.load(model_path, map_location='cpu'))
+ netNetwork.float().eval()
+
+ return cls(netNetwork)
+
+ def to(self, device):
+ self.netNetwork.to(device)
+ self.device = device
+ return self
+
+
+ def __call__(self, input_image, detect_resolution=512, safe=False, output_type="pil", scribble=False, upscale_method="INTER_CUBIC", **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+
+ assert input_image.ndim == 3
+ H, W, C = input_image.shape
+ with torch.no_grad():
+ image_hed = torch.from_numpy(input_image).float().to(self.device)
+ image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
+ edges = self.netNetwork(image_hed)
+ edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
+ edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
+ edges = np.stack(edges, axis=2)
+ edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
+ if safe:
+ edge = safe_step(edge)
+ edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
+
+ detected_map = edge
+
+ if scribble:
+ detected_map = nms(detected_map, 127, 3.0)
+ detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
+ detected_map[detected_map > 4] = 255
+ detected_map[detected_map < 255] = 0
+
+ detected_map = HWC3(remove_pad(detected_map))
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
diff --git a/src/custom_controlnet_aux/leres/__init__.py b/src/custom_controlnet_aux/leres/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e348228730393dc3d2f79f571cba0c0ed95972d
--- /dev/null
+++ b/src/custom_controlnet_aux/leres/__init__.py
@@ -0,0 +1,93 @@
+import os
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+
+from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, HF_MODEL_NAME
+from .leres.depthmap import estimateboost, estimateleres
+from .leres.multi_depth_model_woauxi import RelDepthModel
+from .leres.net_tools import strip_prefix_if_present
+from .pix2pix.models.pix2pix4depth_model import Pix2Pix4DepthModel
+from .pix2pix.options.test_options import TestOptions
+
+
+class LeresDetector:
+ def __init__(self, model, pix2pixmodel):
+ self.model = model
+ self.pix2pixmodel = pix2pixmodel
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, filename="res101.pth", pix2pix_filename="latest_net_G.pth"):
+ model_path = custom_hf_download(pretrained_model_or_path, filename)
+ checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
+
+ model = RelDepthModel(backbone='resnext101')
+ model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."), strict=True)
+ del checkpoint
+
+ pix2pix_model_path = custom_hf_download(pretrained_model_or_path, pix2pix_filename)
+
+ opt = TestOptions().parse()
+ if not torch.cuda.is_available():
+ opt.gpu_ids = [] # cpu mode
+ pix2pixmodel = Pix2Pix4DepthModel(opt)
+ pix2pixmodel.save_dir = os.path.dirname(pix2pix_model_path)
+ pix2pixmodel.load_networks('latest')
+ pix2pixmodel.eval()
+
+ return cls(model, pix2pixmodel)
+
+ def to(self, device):
+ self.model.to(device)
+ # TODO - refactor pix2pix implementation to support device migration
+ # self.pix2pixmodel.to(device)
+ return self
+
+ def __call__(self, input_image, thr_a=0, thr_b=0, boost=False, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+
+ with torch.no_grad():
+ if boost:
+ depth = estimateboost(detected_map, self.model, 0, self.pix2pixmodel, max(detected_map.shape[1], detected_map.shape[0]))
+ else:
+ depth = estimateleres(detected_map, self.model, detected_map.shape[1], detected_map.shape[0])
+
+ numbytes=2
+ depth_min = depth.min()
+ depth_max = depth.max()
+ max_val = (2**(8*numbytes))-1
+
+ # check output before normalizing and mapping to 16 bit
+ if depth_max - depth_min > np.finfo("float").eps:
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
+ else:
+ out = np.zeros(depth.shape)
+
+ # single channel, 16 bit image
+ depth_image = out.astype("uint16")
+
+ # convert to uint8
+ depth_image = cv2.convertScaleAbs(depth_image, alpha=(255.0/65535.0))
+
+ # remove near
+ if thr_a != 0:
+ thr_a = ((thr_a/100)*255)
+ depth_image = cv2.threshold(depth_image, thr_a, 255, cv2.THRESH_TOZERO)[1]
+
+ # invert image
+ depth_image = cv2.bitwise_not(depth_image)
+
+ # remove bg
+ if thr_b != 0:
+ thr_b = ((thr_b/100)*255)
+ depth_image = cv2.threshold(depth_image, thr_b, 255, cv2.THRESH_TOZERO)[1]
+
+ detected_map = HWC3(remove_pad(depth_image))
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/leres/leres/LICENSE b/src/custom_controlnet_aux/leres/leres/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..e0f1d07d98d4e85e684734d058dfe2515d215405
--- /dev/null
+++ b/src/custom_controlnet_aux/leres/leres/LICENSE
@@ -0,0 +1,23 @@
+https://github.com/thygate/stable-diffusion-webui-depthmap-script
+
+MIT License
+
+Copyright (c) 2023 Bob Thiry
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/leres/leres/Resnet.py b/src/custom_controlnet_aux/leres/leres/Resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f12c9975c1aa05401269be3ca3dbaa56bde55581
--- /dev/null
+++ b/src/custom_controlnet_aux/leres/leres/Resnet.py
@@ -0,0 +1,199 @@
+import torch.nn as nn
+import torch.nn as NN
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+ 'resnet152']
+
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = NN.BatchNorm2d(planes) #NN.BatchNorm2d
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = NN.BatchNorm2d(planes) #NN.BatchNorm2d
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = NN.BatchNorm2d(planes) #NN.BatchNorm2d
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+ self.bn2 = NN.BatchNorm2d(planes) #NN.BatchNorm2d
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
+ self.bn3 = NN.BatchNorm2d(planes * self.expansion) #NN.BatchNorm2d
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+
+ def __init__(self, block, layers, num_classes=1000):
+ self.inplanes = 64
+ super(ResNet, self).__init__()
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = NN.BatchNorm2d(64) #NN.BatchNorm2d
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+ #self.avgpool = nn.AvgPool2d(7, stride=1)
+ #self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ NN.BatchNorm2d(planes * block.expansion), #NN.BatchNorm2d
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ features = []
+
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ features.append(x)
+ x = self.layer2(x)
+ features.append(x)
+ x = self.layer3(x)
+ features.append(x)
+ x = self.layer4(x)
+ features.append(x)
+
+ return features
+
+
+def resnet18(pretrained=True, **kwargs):
+ """Constructs a ResNet-18 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
+ return model
+
+
+def resnet34(pretrained=True, **kwargs):
+ """Constructs a ResNet-34 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
+ return model
+
+
+def resnet50(pretrained=True, **kwargs):
+ """Constructs a ResNet-50 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
+
+ return model
+
+
+def resnet101(pretrained=True, **kwargs):
+ """Constructs a ResNet-101 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
+
+ return model
+
+
+def resnet152(pretrained=True, **kwargs):
+ """Constructs a ResNet-152 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
+ return model
diff --git a/src/custom_controlnet_aux/leres/leres/Resnext_torch.py b/src/custom_controlnet_aux/leres/leres/Resnext_torch.py
new file mode 100644
index 0000000000000000000000000000000000000000..9af54fcc3e5b363935ef60c8aaf269110c0d6611
--- /dev/null
+++ b/src/custom_controlnet_aux/leres/leres/Resnext_torch.py
@@ -0,0 +1,237 @@
+#!/usr/bin/env python
+# coding: utf-8
+import torch.nn as nn
+
+try:
+ from urllib import urlretrieve
+except ImportError:
+ from urllib.request import urlretrieve
+
+__all__ = ['resnext101_32x8d']
+
+
+model_urls = {
+ 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+ 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+ base_width=64, dilation=1, norm_layer=None):
+ super(BasicBlock, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ if groups != 1 or base_width != 64:
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = norm_layer(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = norm_layer(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+ base_width=64, dilation=1, norm_layer=None):
+ super(Bottleneck, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ width = int(planes * (base_width / 64.)) * groups
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv1x1(inplanes, width)
+ self.bn1 = norm_layer(width)
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
+ self.bn2 = norm_layer(width)
+ self.conv3 = conv1x1(width, planes * self.expansion)
+ self.bn3 = norm_layer(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+
+ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+ groups=1, width_per_group=64, replace_stride_with_dilation=None,
+ norm_layer=None):
+ super(ResNet, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self._norm_layer = norm_layer
+
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ # each element in the tuple indicates if we should replace
+ # the 2x2 stride with a dilated convolution instead
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = norm_layer(self.inplanes)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+ dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+ dilate=replace_stride_with_dilation[2])
+ #self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ #self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ # Zero-initialize the last BN in each residual branch,
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ nn.init.constant_(m.bn3.weight, 0)
+ elif isinstance(m, BasicBlock):
+ nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation, norm_layer))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes, groups=self.groups,
+ base_width=self.base_width, dilation=self.dilation,
+ norm_layer=norm_layer))
+
+ return nn.Sequential(*layers)
+
+ def _forward_impl(self, x):
+ # See note [TorchScript super()]
+ features = []
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ features.append(x)
+
+ x = self.layer2(x)
+ features.append(x)
+
+ x = self.layer3(x)
+ features.append(x)
+
+ x = self.layer4(x)
+ features.append(x)
+
+ #x = self.avgpool(x)
+ #x = torch.flatten(x, 1)
+ #x = self.fc(x)
+
+ return features
+
+ def forward(self, x):
+ return self._forward_impl(x)
+
+
+
+def resnext101_32x8d(pretrained=True, **kwargs):
+ """Constructs a ResNet-152 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ kwargs['groups'] = 32
+ kwargs['width_per_group'] = 8
+
+ model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
+ return model
+
diff --git a/src/custom_controlnet_aux/leres/leres/__init__.py b/src/custom_controlnet_aux/leres/leres/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/custom_controlnet_aux/leres/leres/depthmap.py b/src/custom_controlnet_aux/leres/leres/depthmap.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc743bf4946b514a53f8d286a395e33c7b612582
--- /dev/null
+++ b/src/custom_controlnet_aux/leres/leres/depthmap.py
@@ -0,0 +1,548 @@
+# Author: thygate
+# https://github.com/thygate/stable-diffusion-webui-depthmap-script
+
+import gc
+from operator import getitem
+
+import cv2
+import numpy as np
+import skimage.measure
+import torch
+from torchvision.transforms import transforms
+
+from ...util import torch_gc
+
+whole_size_threshold = 1600 # R_max from the paper
+pix2pixsize = 1024
+
+def scale_torch(img):
+ """
+ Scale the image and output it in torch.tensor.
+ :param img: input rgb is in shape [H, W, C], input depth/disp is in shape [H, W]
+ :param scale: the scale factor. float
+ :return: img. [C, H, W]
+ """
+ if len(img.shape) == 2:
+ img = img[np.newaxis, :, :]
+ if img.shape[2] == 3:
+ transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406) , (0.229, 0.224, 0.225) )])
+ img = transform(img.astype(np.float32))
+ else:
+ img = img.astype(np.float32)
+ img = torch.from_numpy(img)
+ return img
+
+def estimateleres(img, model, w, h):
+ device = next(iter(model.parameters())).device
+ # leres transform input
+ rgb_c = img[:, :, ::-1].copy()
+ A_resize = cv2.resize(rgb_c, (w, h))
+ img_torch = scale_torch(A_resize)[None, :, :, :]
+
+ # compute
+ with torch.no_grad():
+ img_torch = img_torch.to(device)
+ prediction = model.depth_model(img_torch)
+
+ prediction = prediction.squeeze().cpu().numpy()
+ prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC)
+
+ return prediction
+
+def generatemask(size):
+ # Generates a Guassian mask
+ mask = np.zeros(size, dtype=np.float32)
+ sigma = int(size[0]/16)
+ k_size = int(2 * np.ceil(2 * int(size[0]/16)) + 1)
+ mask[int(0.15*size[0]):size[0] - int(0.15*size[0]), int(0.15*size[1]): size[1] - int(0.15*size[1])] = 1
+ mask = cv2.GaussianBlur(mask, (int(k_size), int(k_size)), sigma)
+ mask = (mask - mask.min()) / (mask.max() - mask.min())
+ mask = mask.astype(np.float32)
+ return mask
+
+def resizewithpool(img, size):
+ i_size = img.shape[0]
+ n = int(np.floor(i_size/size))
+
+ out = skimage.measure.block_reduce(img, (n, n), np.max)
+ return out
+
+def rgb2gray(rgb):
+ # Converts rgb to gray
+ return np.dot(rgb[..., :3], [0.2989, 0.5870, 0.1140])
+
+def calculateprocessingres(img, basesize, confidence=0.1, scale_threshold=3, whole_size_threshold=3000):
+ # Returns the R_x resolution described in section 5 of the main paper.
+
+ # Parameters:
+ # img :input rgb image
+ # basesize : size the dilation kernel which is equal to receptive field of the network.
+ # confidence: value of x in R_x; allowed percentage of pixels that are not getting any contextual cue.
+ # scale_threshold: maximum allowed upscaling on the input image ; it has been set to 3.
+ # whole_size_threshold: maximum allowed resolution. (R_max from section 6 of the main paper)
+
+ # Returns:
+ # outputsize_scale*speed_scale :The computed R_x resolution
+ # patch_scale: K parameter from section 6 of the paper
+
+ # speed scale parameter is to process every image in a smaller size to accelerate the R_x resolution search
+ speed_scale = 32
+ image_dim = int(min(img.shape[0:2]))
+
+ gray = rgb2gray(img)
+ grad = np.abs(cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)) + np.abs(cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3))
+ grad = cv2.resize(grad, (image_dim, image_dim), cv2.INTER_AREA)
+
+ # thresholding the gradient map to generate the edge-map as a proxy of the contextual cues
+ m = grad.min()
+ M = grad.max()
+ middle = m + (0.4 * (M - m))
+ grad[grad < middle] = 0
+ grad[grad >= middle] = 1
+
+ # dilation kernel with size of the receptive field
+ kernel = np.ones((int(basesize/speed_scale), int(basesize/speed_scale)), float)
+ # dilation kernel with size of the a quarter of receptive field used to compute k
+ # as described in section 6 of main paper
+ kernel2 = np.ones((int(basesize / (4*speed_scale)), int(basesize / (4*speed_scale))), float)
+
+ # Output resolution limit set by the whole_size_threshold and scale_threshold.
+ threshold = min(whole_size_threshold, scale_threshold * max(img.shape[:2]))
+
+ outputsize_scale = basesize / speed_scale
+ for p_size in range(int(basesize/speed_scale), int(threshold/speed_scale), int(basesize / (2*speed_scale))):
+ grad_resized = resizewithpool(grad, p_size)
+ grad_resized = cv2.resize(grad_resized, (p_size, p_size), cv2.INTER_NEAREST)
+ grad_resized[grad_resized >= 0.5] = 1
+ grad_resized[grad_resized < 0.5] = 0
+
+ dilated = cv2.dilate(grad_resized, kernel, iterations=1)
+ meanvalue = (1-dilated).mean()
+ if meanvalue > confidence:
+ break
+ else:
+ outputsize_scale = p_size
+
+ grad_region = cv2.dilate(grad_resized, kernel2, iterations=1)
+ patch_scale = grad_region.mean()
+
+ return int(outputsize_scale*speed_scale), patch_scale
+
+# Generate a double-input depth estimation
+def doubleestimate(img, size1, size2, pix2pixsize, model, net_type, pix2pixmodel):
+ # Generate the low resolution estimation
+ estimate1 = singleestimate(img, size1, model, net_type)
+ # Resize to the inference size of merge network.
+ estimate1 = cv2.resize(estimate1, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC)
+
+ # Generate the high resolution estimation
+ estimate2 = singleestimate(img, size2, model, net_type)
+ # Resize to the inference size of merge network.
+ estimate2 = cv2.resize(estimate2, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC)
+
+ # Inference on the merge model
+ pix2pixmodel.set_input(estimate1, estimate2)
+ pix2pixmodel.test()
+ visuals = pix2pixmodel.get_current_visuals()
+ prediction_mapped = visuals['fake_B']
+ prediction_mapped = (prediction_mapped+1)/2
+ prediction_mapped = (prediction_mapped - torch.min(prediction_mapped)) / (
+ torch.max(prediction_mapped) - torch.min(prediction_mapped))
+ prediction_mapped = prediction_mapped.squeeze().cpu().numpy()
+
+ return prediction_mapped
+
+# Generate a single-input depth estimation
+def singleestimate(img, msize, model, net_type):
+ # if net_type == 0:
+ return estimateleres(img, model, msize, msize)
+ # else:
+ # return estimatemidasBoost(img, model, msize, msize)
+
+def applyGridpatch(blsize, stride, img, box):
+ # Extract a simple grid patch.
+ counter1 = 0
+ patch_bound_list = {}
+ for k in range(blsize, img.shape[1] - blsize, stride):
+ for j in range(blsize, img.shape[0] - blsize, stride):
+ patch_bound_list[str(counter1)] = {}
+ patchbounds = [j - blsize, k - blsize, j - blsize + 2 * blsize, k - blsize + 2 * blsize]
+ patch_bound = [box[0] + patchbounds[1], box[1] + patchbounds[0], patchbounds[3] - patchbounds[1],
+ patchbounds[2] - patchbounds[0]]
+ patch_bound_list[str(counter1)]['rect'] = patch_bound
+ patch_bound_list[str(counter1)]['size'] = patch_bound[2]
+ counter1 = counter1 + 1
+ return patch_bound_list
+
+# Generating local patches to perform the local refinement described in section 6 of the main paper.
+def generatepatchs(img, base_size):
+
+ # Compute the gradients as a proxy of the contextual cues.
+ img_gray = rgb2gray(img)
+ whole_grad = np.abs(cv2.Sobel(img_gray, cv2.CV_64F, 0, 1, ksize=3)) +\
+ np.abs(cv2.Sobel(img_gray, cv2.CV_64F, 1, 0, ksize=3))
+
+ threshold = whole_grad[whole_grad > 0].mean()
+ whole_grad[whole_grad < threshold] = 0
+
+ # We use the integral image to speed-up the evaluation of the amount of gradients for each patch.
+ gf = whole_grad.sum()/len(whole_grad.reshape(-1))
+ grad_integral_image = cv2.integral(whole_grad)
+
+ # Variables are selected such that the initial patch size would be the receptive field size
+ # and the stride is set to 1/3 of the receptive field size.
+ blsize = int(round(base_size/2))
+ stride = int(round(blsize*0.75))
+
+ # Get initial Grid
+ patch_bound_list = applyGridpatch(blsize, stride, img, [0, 0, 0, 0])
+
+ # Refine initial Grid of patches by discarding the flat (in terms of gradients of the rgb image) ones. Refine
+ # each patch size to ensure that there will be enough depth cues for the network to generate a consistent depth map.
+ print("Selecting patches ...")
+ patch_bound_list = adaptiveselection(grad_integral_image, patch_bound_list, gf)
+
+ # Sort the patch list to make sure the merging operation will be done with the correct order: starting from biggest
+ # patch
+ patchset = sorted(patch_bound_list.items(), key=lambda x: getitem(x[1], 'size'), reverse=True)
+ return patchset
+
+def getGF_fromintegral(integralimage, rect):
+ # Computes the gradient density of a given patch from the gradient integral image.
+ x1 = rect[1]
+ x2 = rect[1]+rect[3]
+ y1 = rect[0]
+ y2 = rect[0]+rect[2]
+ value = integralimage[x2, y2]-integralimage[x1, y2]-integralimage[x2, y1]+integralimage[x1, y1]
+ return value
+
+# Adaptively select patches
+def adaptiveselection(integral_grad, patch_bound_list, gf):
+ patchlist = {}
+ count = 0
+ height, width = integral_grad.shape
+
+ search_step = int(32/factor)
+
+ # Go through all patches
+ for c in range(len(patch_bound_list)):
+ # Get patch
+ bbox = patch_bound_list[str(c)]['rect']
+
+ # Compute the amount of gradients present in the patch from the integral image.
+ cgf = getGF_fromintegral(integral_grad, bbox)/(bbox[2]*bbox[3])
+
+ # Check if patching is beneficial by comparing the gradient density of the patch to
+ # the gradient density of the whole image
+ if cgf >= gf:
+ bbox_test = bbox.copy()
+ patchlist[str(count)] = {}
+
+ # Enlarge each patch until the gradient density of the patch is equal
+ # to the whole image gradient density
+ while True:
+
+ bbox_test[0] = bbox_test[0] - int(search_step/2)
+ bbox_test[1] = bbox_test[1] - int(search_step/2)
+
+ bbox_test[2] = bbox_test[2] + search_step
+ bbox_test[3] = bbox_test[3] + search_step
+
+ # Check if we are still within the image
+ if bbox_test[0] < 0 or bbox_test[1] < 0 or bbox_test[1] + bbox_test[3] >= height \
+ or bbox_test[0] + bbox_test[2] >= width:
+ break
+
+ # Compare gradient density
+ cgf = getGF_fromintegral(integral_grad, bbox_test)/(bbox_test[2]*bbox_test[3])
+ if cgf < gf:
+ break
+ bbox = bbox_test.copy()
+
+ # Add patch to selected patches
+ patchlist[str(count)]['rect'] = bbox
+ patchlist[str(count)]['size'] = bbox[2]
+ count = count + 1
+
+ # Return selected patches
+ return patchlist
+
+def impatch(image, rect):
+ # Extract the given patch pixels from a given image.
+ w1 = rect[0]
+ h1 = rect[1]
+ w2 = w1 + rect[2]
+ h2 = h1 + rect[3]
+ image_patch = image[h1:h2, w1:w2]
+ return image_patch
+
+class ImageandPatchs:
+ def __init__(self, root_dir, name, patchsinfo, rgb_image, scale=1):
+ self.root_dir = root_dir
+ self.patchsinfo = patchsinfo
+ self.name = name
+ self.patchs = patchsinfo
+ self.scale = scale
+
+ self.rgb_image = cv2.resize(rgb_image, (round(rgb_image.shape[1]*scale), round(rgb_image.shape[0]*scale)),
+ interpolation=cv2.INTER_CUBIC)
+
+ self.do_have_estimate = False
+ self.estimation_updated_image = None
+ self.estimation_base_image = None
+
+ def __len__(self):
+ return len(self.patchs)
+
+ def set_base_estimate(self, est):
+ self.estimation_base_image = est
+ if self.estimation_updated_image is not None:
+ self.do_have_estimate = True
+
+ def set_updated_estimate(self, est):
+ self.estimation_updated_image = est
+ if self.estimation_base_image is not None:
+ self.do_have_estimate = True
+
+ def __getitem__(self, index):
+ patch_id = int(self.patchs[index][0])
+ rect = np.array(self.patchs[index][1]['rect'])
+ msize = self.patchs[index][1]['size']
+
+ ## applying scale to rect:
+ rect = np.round(rect * self.scale)
+ rect = rect.astype('int')
+ msize = round(msize * self.scale)
+
+ patch_rgb = impatch(self.rgb_image, rect)
+ if self.do_have_estimate:
+ patch_whole_estimate_base = impatch(self.estimation_base_image, rect)
+ patch_whole_estimate_updated = impatch(self.estimation_updated_image, rect)
+ return {'patch_rgb': patch_rgb, 'patch_whole_estimate_base': patch_whole_estimate_base,
+ 'patch_whole_estimate_updated': patch_whole_estimate_updated, 'rect': rect,
+ 'size': msize, 'id': patch_id}
+ else:
+ return {'patch_rgb': patch_rgb, 'rect': rect, 'size': msize, 'id': patch_id}
+
+ def print_options(self, opt):
+ """Print and save options
+
+ It will print both current options and default values(if different).
+ It will save options into a text file / [checkpoints_dir] / opt.txt
+ """
+ message = ''
+ message += '----------------- Options ---------------\n'
+ for k, v in sorted(vars(opt).items()):
+ comment = ''
+ default = self.parser.get_default(k)
+ if v != default:
+ comment = '\t[default: %s]' % str(default)
+ message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
+ message += '----------------- End -------------------'
+ print(message)
+
+ # save to the disk
+ """
+ expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
+ util.mkdirs(expr_dir)
+ file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
+ with open(file_name, 'wt') as opt_file:
+ opt_file.write(message)
+ opt_file.write('\n')
+ """
+
+ def parse(self):
+ """Parse our options, create checkpoints directory suffix, and set up gpu device."""
+ opt = self.gather_options()
+ opt.isTrain = self.isTrain # train or test
+
+ # process opt.suffix
+ if opt.suffix:
+ suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
+ opt.name = opt.name + suffix
+
+ #self.print_options(opt)
+
+ # set gpu ids
+ str_ids = opt.gpu_ids.split(',')
+ opt.gpu_ids = []
+ for str_id in str_ids:
+ id = int(str_id)
+ if id >= 0:
+ opt.gpu_ids.append(id)
+ #if len(opt.gpu_ids) > 0:
+ # torch.cuda.set_device(opt.gpu_ids[0])
+
+ self.opt = opt
+ return self.opt
+
+
+def estimateboost(img, model, model_type, pix2pixmodel, max_res=512, depthmap_script_boost_rmax=None):
+ global whole_size_threshold
+
+ # get settings
+ if depthmap_script_boost_rmax:
+ whole_size_threshold = depthmap_script_boost_rmax
+
+ if model_type == 0: #leres
+ net_receptive_field_size = 448
+ patch_netsize = 2 * net_receptive_field_size
+ elif model_type == 1: #dpt_beit_large_512
+ net_receptive_field_size = 512
+ patch_netsize = 2 * net_receptive_field_size
+ else: #other midas
+ net_receptive_field_size = 384
+ patch_netsize = 2 * net_receptive_field_size
+
+ gc.collect()
+ torch_gc()
+
+ # Generate mask used to smoothly blend the local pathc estimations to the base estimate.
+ # It is arbitrarily large to avoid artifacts during rescaling for each crop.
+ mask_org = generatemask((3000, 3000))
+ mask = mask_org.copy()
+
+ # Value x of R_x defined in the section 5 of the main paper.
+ r_threshold_value = 0.2
+ #if R0:
+ # r_threshold_value = 0
+
+ input_resolution = img.shape
+ scale_threshold = 3 # Allows up-scaling with a scale up to 3
+
+ # Find the best input resolution R-x. The resolution search described in section 5-double estimation of the main paper and section B of the
+ # supplementary material.
+ whole_image_optimal_size, patch_scale = calculateprocessingres(img, net_receptive_field_size, r_threshold_value, scale_threshold, whole_size_threshold)
+
+ # print('wholeImage being processed in :', whole_image_optimal_size)
+
+ # Generate the base estimate using the double estimation.
+ whole_estimate = doubleestimate(img, net_receptive_field_size, whole_image_optimal_size, pix2pixsize, model, model_type, pix2pixmodel)
+
+ # Compute the multiplier described in section 6 of the main paper to make sure our initial patch can select
+ # small high-density regions of the image.
+ global factor
+ factor = max(min(1, 4 * patch_scale * whole_image_optimal_size / whole_size_threshold), 0.2)
+ # print('Adjust factor is:', 1/factor)
+
+ # Check if Local boosting is beneficial.
+ if max_res < whole_image_optimal_size:
+ # print("No Local boosting. Specified Max Res is smaller than R20, Returning doubleestimate result")
+ return cv2.resize(whole_estimate, (input_resolution[1], input_resolution[0]), interpolation=cv2.INTER_CUBIC)
+
+ # Compute the default target resolution.
+ if img.shape[0] > img.shape[1]:
+ a = 2 * whole_image_optimal_size
+ b = round(2 * whole_image_optimal_size * img.shape[1] / img.shape[0])
+ else:
+ a = round(2 * whole_image_optimal_size * img.shape[0] / img.shape[1])
+ b = 2 * whole_image_optimal_size
+ b = int(round(b / factor))
+ a = int(round(a / factor))
+
+ """
+ # recompute a, b and saturate to max res.
+ if max(a,b) > max_res:
+ print('Default Res is higher than max-res: Reducing final resolution')
+ if img.shape[0] > img.shape[1]:
+ a = max_res
+ b = round(max_res * img.shape[1] / img.shape[0])
+ else:
+ a = round(max_res * img.shape[0] / img.shape[1])
+ b = max_res
+ b = int(b)
+ a = int(a)
+ """
+
+ img = cv2.resize(img, (b, a), interpolation=cv2.INTER_CUBIC)
+
+ # Extract selected patches for local refinement
+ base_size = net_receptive_field_size * 2
+ patchset = generatepatchs(img, base_size)
+
+ # print('Target resolution: ', img.shape)
+
+ # Computing a scale in case user prompted to generate the results as the same resolution of the input.
+ # Notice that our method output resolution is independent of the input resolution and this parameter will only
+ # enable a scaling operation during the local patch merge implementation to generate results with the same resolution
+ # as the input.
+ """
+ if output_resolution == 1:
+ mergein_scale = input_resolution[0] / img.shape[0]
+ print('Dynamicly change merged-in resolution; scale:', mergein_scale)
+ else:
+ mergein_scale = 1
+ """
+ # always rescale to input res for now
+ mergein_scale = input_resolution[0] / img.shape[0]
+
+ imageandpatchs = ImageandPatchs('', '', patchset, img, mergein_scale)
+ whole_estimate_resized = cv2.resize(whole_estimate, (round(img.shape[1]*mergein_scale),
+ round(img.shape[0]*mergein_scale)), interpolation=cv2.INTER_CUBIC)
+ imageandpatchs.set_base_estimate(whole_estimate_resized.copy())
+ imageandpatchs.set_updated_estimate(whole_estimate_resized.copy())
+
+ print('Resulting depthmap resolution will be :', whole_estimate_resized.shape[:2])
+ print('Patches to process: '+str(len(imageandpatchs)))
+
+ # Enumerate through all patches, generate their estimations and refining the base estimate.
+ for patch_ind in range(len(imageandpatchs)):
+
+ # Get patch information
+ patch = imageandpatchs[patch_ind] # patch object
+ patch_rgb = patch['patch_rgb'] # rgb patch
+ patch_whole_estimate_base = patch['patch_whole_estimate_base'] # corresponding patch from base
+ rect = patch['rect'] # patch size and location
+ patch_id = patch['id'] # patch ID
+ org_size = patch_whole_estimate_base.shape # the original size from the unscaled input
+ print('\t Processing patch', patch_ind, '/', len(imageandpatchs)-1, '|', rect)
+
+ # We apply double estimation for patches. The high resolution value is fixed to twice the receptive
+ # field size of the network for patches to accelerate the process.
+ patch_estimation = doubleestimate(patch_rgb, net_receptive_field_size, patch_netsize, pix2pixsize, model, model_type, pix2pixmodel)
+ patch_estimation = cv2.resize(patch_estimation, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC)
+ patch_whole_estimate_base = cv2.resize(patch_whole_estimate_base, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC)
+
+ # Merging the patch estimation into the base estimate using our merge network:
+ # We feed the patch estimation and the same region from the updated base estimate to the merge network
+ # to generate the target estimate for the corresponding region.
+ pix2pixmodel.set_input(patch_whole_estimate_base, patch_estimation)
+
+ # Run merging network
+ pix2pixmodel.test()
+ visuals = pix2pixmodel.get_current_visuals()
+
+ prediction_mapped = visuals['fake_B']
+ prediction_mapped = (prediction_mapped+1)/2
+ prediction_mapped = prediction_mapped.squeeze().cpu().numpy()
+
+ mapped = prediction_mapped
+
+ # We use a simple linear polynomial to make sure the result of the merge network would match the values of
+ # base estimate
+ p_coef = np.polyfit(mapped.reshape(-1), patch_whole_estimate_base.reshape(-1), deg=1)
+ merged = np.polyval(p_coef, mapped.reshape(-1)).reshape(mapped.shape)
+
+ merged = cv2.resize(merged, (org_size[1],org_size[0]), interpolation=cv2.INTER_CUBIC)
+
+ # Get patch size and location
+ w1 = rect[0]
+ h1 = rect[1]
+ w2 = w1 + rect[2]
+ h2 = h1 + rect[3]
+
+ # To speed up the implementation, we only generate the Gaussian mask once with a sufficiently large size
+ # and resize it to our needed size while merging the patches.
+ if mask.shape != org_size:
+ mask = cv2.resize(mask_org, (org_size[1],org_size[0]), interpolation=cv2.INTER_LINEAR)
+
+ tobemergedto = imageandpatchs.estimation_updated_image
+
+ # Update the whole estimation:
+ # We use a simple Gaussian mask to blend the merged patch region with the base estimate to ensure seamless
+ # blending at the boundaries of the patch region.
+ tobemergedto[h1:h2, w1:w2] = np.multiply(tobemergedto[h1:h2, w1:w2], 1 - mask) + np.multiply(merged, mask)
+ imageandpatchs.set_updated_estimate(tobemergedto)
+
+ # output
+ return cv2.resize(imageandpatchs.estimation_updated_image, (input_resolution[1], input_resolution[0]), interpolation=cv2.INTER_CUBIC)
diff --git a/src/custom_controlnet_aux/leres/leres/multi_depth_model_woauxi.py b/src/custom_controlnet_aux/leres/leres/multi_depth_model_woauxi.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdf35d7843e00be5d3c831d72b9ab5d64d130f93
--- /dev/null
+++ b/src/custom_controlnet_aux/leres/leres/multi_depth_model_woauxi.py
@@ -0,0 +1,35 @@
+import torch
+import torch.nn as nn
+
+from . import network_auxi as network
+from .net_tools import get_func
+
+
+class RelDepthModel(nn.Module):
+ def __init__(self, backbone='resnet50'):
+ super(RelDepthModel, self).__init__()
+ if backbone == 'resnet50':
+ encoder = 'resnet50_stride32'
+ elif backbone == 'resnext101':
+ encoder = 'resnext101_stride32x8d'
+ self.depth_model = DepthModel(encoder)
+
+ def inference(self, rgb):
+ with torch.no_grad():
+ input = rgb.to(self.depth_model.device)
+ depth = self.depth_model(input)
+ #pred_depth_out = depth - depth.min() + 0.01
+ return depth #pred_depth_out
+
+
+class DepthModel(nn.Module):
+ def __init__(self, encoder):
+ super(DepthModel, self).__init__()
+ backbone = network.__name__.split('.')[-1] + '.' + encoder
+ self.encoder_modules = get_func(backbone)()
+ self.decoder_modules = network.Decoder()
+
+ def forward(self, x):
+ lateral_out = self.encoder_modules(x)
+ out_logit = self.decoder_modules(lateral_out)
+ return out_logit
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/leres/leres/net_tools.py b/src/custom_controlnet_aux/leres/leres/net_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..d44c794d6a81cb0309fb4873f83489de377e30a8
--- /dev/null
+++ b/src/custom_controlnet_aux/leres/leres/net_tools.py
@@ -0,0 +1,54 @@
+import importlib
+import torch
+import os
+from collections import OrderedDict
+
+
+def get_func(func_name):
+ """Helper to return a function object by name. func_name must identify a
+ function in this module or the path to a function relative to the base
+ 'modeling' module.
+ """
+ if func_name == '':
+ return None
+ try:
+ parts = func_name.split('.')
+ # Refers to a function in this module
+ if len(parts) == 1:
+ return globals()[parts[0]]
+ # Otherwise, assume we're referencing a module under modeling
+ module_name = 'custom_controlnet_aux.leres.leres.' + '.'.join(parts[:-1])
+ module = importlib.import_module(module_name)
+ return getattr(module, parts[-1])
+ except Exception:
+ print('Failed to f1ind function: %s', func_name)
+ raise
+
+def load_ckpt(args, depth_model, shift_model, focal_model):
+ """
+ Load checkpoint.
+ """
+ if os.path.isfile(args.load_ckpt):
+ print("loading checkpoint %s" % args.load_ckpt)
+ checkpoint = torch.load(args.load_ckpt)
+ if shift_model is not None:
+ shift_model.load_state_dict(strip_prefix_if_present(checkpoint['shift_model'], 'module.'),
+ strict=True)
+ if focal_model is not None:
+ focal_model.load_state_dict(strip_prefix_if_present(checkpoint['focal_model'], 'module.'),
+ strict=True)
+ depth_model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."),
+ strict=True)
+ del checkpoint
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+
+def strip_prefix_if_present(state_dict, prefix):
+ keys = sorted(state_dict.keys())
+ if not all(key.startswith(prefix) for key in keys):
+ return state_dict
+ stripped_state_dict = OrderedDict()
+ for key, value in state_dict.items():
+ stripped_state_dict[key.replace(prefix, "")] = value
+ return stripped_state_dict
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/leres/leres/network_auxi.py b/src/custom_controlnet_aux/leres/leres/network_auxi.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bd87011a5339aca632d1a10b217c8737bdc794f
--- /dev/null
+++ b/src/custom_controlnet_aux/leres/leres/network_auxi.py
@@ -0,0 +1,417 @@
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+
+from . import Resnet, Resnext_torch
+
+
+def resnet50_stride32():
+ return DepthNet(backbone='resnet', depth=50, upfactors=[2, 2, 2, 2])
+
+def resnext101_stride32x8d():
+ return DepthNet(backbone='resnext101_32x8d', depth=101, upfactors=[2, 2, 2, 2])
+
+
+class Decoder(nn.Module):
+ def __init__(self):
+ super(Decoder, self).__init__()
+ self.inchannels = [256, 512, 1024, 2048]
+ self.midchannels = [256, 256, 256, 512]
+ self.upfactors = [2,2,2,2]
+ self.outchannels = 1
+
+ self.conv = FTB(inchannels=self.inchannels[3], midchannels=self.midchannels[3])
+ self.conv1 = nn.Conv2d(in_channels=self.midchannels[3], out_channels=self.midchannels[2], kernel_size=3, padding=1, stride=1, bias=True)
+ self.upsample = nn.Upsample(scale_factor=self.upfactors[3], mode='bilinear', align_corners=True)
+
+ self.ffm2 = FFM(inchannels=self.inchannels[2], midchannels=self.midchannels[2], outchannels = self.midchannels[2], upfactor=self.upfactors[2])
+ self.ffm1 = FFM(inchannels=self.inchannels[1], midchannels=self.midchannels[1], outchannels = self.midchannels[1], upfactor=self.upfactors[1])
+ self.ffm0 = FFM(inchannels=self.inchannels[0], midchannels=self.midchannels[0], outchannels = self.midchannels[0], upfactor=self.upfactors[0])
+
+ self.outconv = AO(inchannels=self.midchannels[0], outchannels=self.outchannels, upfactor=2)
+ self._init_params()
+
+ def _init_params(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ init.normal_(m.weight, std=0.01)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.ConvTranspose2d):
+ init.normal_(m.weight, std=0.01)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d): #NN.BatchNorm2d
+ init.constant_(m.weight, 1)
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ init.normal_(m.weight, std=0.01)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+
+ def forward(self, features):
+ x_32x = self.conv(features[3]) # 1/32
+ x_32 = self.conv1(x_32x)
+ x_16 = self.upsample(x_32) # 1/16
+
+ x_8 = self.ffm2(features[2], x_16) # 1/8
+ x_4 = self.ffm1(features[1], x_8) # 1/4
+ x_2 = self.ffm0(features[0], x_4) # 1/2
+ #-----------------------------------------
+ x = self.outconv(x_2) # original size
+ return x
+
+class DepthNet(nn.Module):
+ __factory = {
+ 18: Resnet.resnet18,
+ 34: Resnet.resnet34,
+ 50: Resnet.resnet50,
+ 101: Resnet.resnet101,
+ 152: Resnet.resnet152
+ }
+ def __init__(self,
+ backbone='resnet',
+ depth=50,
+ upfactors=[2, 2, 2, 2]):
+ super(DepthNet, self).__init__()
+ self.backbone = backbone
+ self.depth = depth
+ self.pretrained = False
+ self.inchannels = [256, 512, 1024, 2048]
+ self.midchannels = [256, 256, 256, 512]
+ self.upfactors = upfactors
+ self.outchannels = 1
+
+ # Build model
+ if self.backbone == 'resnet':
+ if self.depth not in DepthNet.__factory:
+ raise KeyError("Unsupported depth:", self.depth)
+ self.encoder = DepthNet.__factory[depth](pretrained=self.pretrained)
+ elif self.backbone == 'resnext101_32x8d':
+ self.encoder = Resnext_torch.resnext101_32x8d(pretrained=self.pretrained)
+ else:
+ self.encoder = Resnext_torch.resnext101(pretrained=self.pretrained)
+
+ def forward(self, x):
+ x = self.encoder(x) # 1/32, 1/16, 1/8, 1/4
+ return x
+
+
+class FTB(nn.Module):
+ def __init__(self, inchannels, midchannels=512):
+ super(FTB, self).__init__()
+ self.in1 = inchannels
+ self.mid = midchannels
+ self.conv1 = nn.Conv2d(in_channels=self.in1, out_channels=self.mid, kernel_size=3, padding=1, stride=1,
+ bias=True)
+ # NN.BatchNorm2d
+ self.conv_branch = nn.Sequential(nn.ReLU(inplace=True), \
+ nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3,
+ padding=1, stride=1, bias=True), \
+ nn.BatchNorm2d(num_features=self.mid), \
+ nn.ReLU(inplace=True), \
+ nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3,
+ padding=1, stride=1, bias=True))
+ self.relu = nn.ReLU(inplace=True)
+
+ self.init_params()
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = x + self.conv_branch(x)
+ x = self.relu(x)
+
+ return x
+
+ def init_params(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ init.normal_(m.weight, std=0.01)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.ConvTranspose2d):
+ # init.kaiming_normal_(m.weight, mode='fan_out')
+ init.normal_(m.weight, std=0.01)
+ # init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d
+ init.constant_(m.weight, 1)
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ init.normal_(m.weight, std=0.01)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+
+
+class ATA(nn.Module):
+ def __init__(self, inchannels, reduction=8):
+ super(ATA, self).__init__()
+ self.inchannels = inchannels
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.fc = nn.Sequential(nn.Linear(self.inchannels * 2, self.inchannels // reduction),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.inchannels // reduction, self.inchannels),
+ nn.Sigmoid())
+ self.init_params()
+
+ def forward(self, low_x, high_x):
+ n, c, _, _ = low_x.size()
+ x = torch.cat([low_x, high_x], 1)
+ x = self.avg_pool(x)
+ x = x.view(n, -1)
+ x = self.fc(x).view(n, c, 1, 1)
+ x = low_x * x + high_x
+
+ return x
+
+ def init_params(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ # init.kaiming_normal_(m.weight, mode='fan_out')
+ # init.normal(m.weight, std=0.01)
+ init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.ConvTranspose2d):
+ # init.kaiming_normal_(m.weight, mode='fan_out')
+ # init.normal_(m.weight, std=0.01)
+ init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d
+ init.constant_(m.weight, 1)
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ init.normal_(m.weight, std=0.01)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+
+
+class FFM(nn.Module):
+ def __init__(self, inchannels, midchannels, outchannels, upfactor=2):
+ super(FFM, self).__init__()
+ self.inchannels = inchannels
+ self.midchannels = midchannels
+ self.outchannels = outchannels
+ self.upfactor = upfactor
+
+ self.ftb1 = FTB(inchannels=self.inchannels, midchannels=self.midchannels)
+ # self.ata = ATA(inchannels = self.midchannels)
+ self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels)
+
+ self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)
+
+ self.init_params()
+
+ def forward(self, low_x, high_x):
+ x = self.ftb1(low_x)
+ x = x + high_x
+ x = self.ftb2(x)
+ x = self.upsample(x)
+
+ return x
+
+ def init_params(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ # init.kaiming_normal_(m.weight, mode='fan_out')
+ init.normal_(m.weight, std=0.01)
+ # init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.ConvTranspose2d):
+ # init.kaiming_normal_(m.weight, mode='fan_out')
+ init.normal_(m.weight, std=0.01)
+ # init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d): # NN.Batchnorm2d
+ init.constant_(m.weight, 1)
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ init.normal_(m.weight, std=0.01)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+
+
+class AO(nn.Module):
+ # Adaptive output module
+ def __init__(self, inchannels, outchannels, upfactor=2):
+ super(AO, self).__init__()
+ self.inchannels = inchannels
+ self.outchannels = outchannels
+ self.upfactor = upfactor
+
+ self.adapt_conv = nn.Sequential(
+ nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels // 2, kernel_size=3, padding=1,
+ stride=1, bias=True), \
+ nn.BatchNorm2d(num_features=self.inchannels // 2), \
+ nn.ReLU(inplace=True), \
+ nn.Conv2d(in_channels=self.inchannels // 2, out_channels=self.outchannels, kernel_size=3, padding=1,
+ stride=1, bias=True), \
+ nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True))
+
+ self.init_params()
+
+ def forward(self, x):
+ x = self.adapt_conv(x)
+ return x
+
+ def init_params(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ # init.kaiming_normal_(m.weight, mode='fan_out')
+ init.normal_(m.weight, std=0.01)
+ # init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.ConvTranspose2d):
+ # init.kaiming_normal_(m.weight, mode='fan_out')
+ init.normal_(m.weight, std=0.01)
+ # init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d): # NN.Batchnorm2d
+ init.constant_(m.weight, 1)
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ init.normal_(m.weight, std=0.01)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+
+
+
+# ==============================================================================================================
+
+
+class ResidualConv(nn.Module):
+ def __init__(self, inchannels):
+ super(ResidualConv, self).__init__()
+ # NN.BatchNorm2d
+ self.conv = nn.Sequential(
+ # nn.BatchNorm2d(num_features=inchannels),
+ nn.ReLU(inplace=False),
+ # nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=3, padding=1, stride=1, groups=inchannels,bias=True),
+ # nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=1, padding=0, stride=1, groups=1,bias=True)
+ nn.Conv2d(in_channels=inchannels, out_channels=inchannels / 2, kernel_size=3, padding=1, stride=1,
+ bias=False),
+ nn.BatchNorm2d(num_features=inchannels / 2),
+ nn.ReLU(inplace=False),
+ nn.Conv2d(in_channels=inchannels / 2, out_channels=inchannels, kernel_size=3, padding=1, stride=1,
+ bias=False)
+ )
+ self.init_params()
+
+ def forward(self, x):
+ x = self.conv(x) + x
+ return x
+
+ def init_params(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ # init.kaiming_normal_(m.weight, mode='fan_out')
+ init.normal_(m.weight, std=0.01)
+ # init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.ConvTranspose2d):
+ # init.kaiming_normal_(m.weight, mode='fan_out')
+ init.normal_(m.weight, std=0.01)
+ # init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d
+ init.constant_(m.weight, 1)
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ init.normal_(m.weight, std=0.01)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+
+
+class FeatureFusion(nn.Module):
+ def __init__(self, inchannels, outchannels):
+ super(FeatureFusion, self).__init__()
+ self.conv = ResidualConv(inchannels=inchannels)
+ # NN.BatchNorm2d
+ self.up = nn.Sequential(ResidualConv(inchannels=inchannels),
+ nn.ConvTranspose2d(in_channels=inchannels, out_channels=outchannels, kernel_size=3,
+ stride=2, padding=1, output_padding=1),
+ nn.BatchNorm2d(num_features=outchannels),
+ nn.ReLU(inplace=True))
+
+ def forward(self, lowfeat, highfeat):
+ return self.up(highfeat + self.conv(lowfeat))
+
+ def init_params(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ # init.kaiming_normal_(m.weight, mode='fan_out')
+ init.normal_(m.weight, std=0.01)
+ # init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.ConvTranspose2d):
+ # init.kaiming_normal_(m.weight, mode='fan_out')
+ init.normal_(m.weight, std=0.01)
+ # init.xavier_normal_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d
+ init.constant_(m.weight, 1)
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ init.normal_(m.weight, std=0.01)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+
+
+class SenceUnderstand(nn.Module):
+ def __init__(self, channels):
+ super(SenceUnderstand, self).__init__()
+ self.channels = channels
+ self.conv1 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True))
+ self.pool = nn.AdaptiveAvgPool2d(8)
+ self.fc = nn.Sequential(nn.Linear(512 * 8 * 8, self.channels),
+ nn.ReLU(inplace=True))
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(in_channels=self.channels, out_channels=self.channels, kernel_size=1, padding=0),
+ nn.ReLU(inplace=True))
+ self.initial_params()
+
+ def forward(self, x):
+ n, c, h, w = x.size()
+ x = self.conv1(x)
+ x = self.pool(x)
+ x = x.view(n, -1)
+ x = self.fc(x)
+ x = x.view(n, self.channels, 1, 1)
+ x = self.conv2(x)
+ x = x.repeat(1, 1, h, w)
+ return x
+
+ def initial_params(self, dev=0.01):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ # print torch.sum(m.weight)
+ m.weight.data.normal_(0, dev)
+ if m.bias is not None:
+ m.bias.data.fill_(0)
+ elif isinstance(m, nn.ConvTranspose2d):
+ # print torch.sum(m.weight)
+ m.weight.data.normal_(0, dev)
+ if m.bias is not None:
+ m.bias.data.fill_(0)
+ elif isinstance(m, nn.Linear):
+ m.weight.data.normal_(0, dev)
+
+
+if __name__ == '__main__':
+ net = DepthNet(depth=50, pretrained=True)
+ print(net)
+ inputs = torch.ones(4,3,128,128)
+ out = net(inputs)
+ print(out.size())
+
diff --git a/src/custom_controlnet_aux/leres/pix2pix/LICENSE b/src/custom_controlnet_aux/leres/pix2pix/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..38b1a24fd389a138b930dcf1ee606ef97a0186c8
--- /dev/null
+++ b/src/custom_controlnet_aux/leres/pix2pix/LICENSE
@@ -0,0 +1,19 @@
+https://github.com/compphoto/BoostingMonocularDepth
+
+Copyright 2021, Seyed Mahdi Hosseini Miangoleh, Sebastian Dille, Computational Photography Laboratory. All rights reserved.
+
+This software is for academic use only. A redistribution of this
+software, with or without modifications, has to be for academic
+use only, while giving the appropriate credit to the original
+authors of the software. The methods implemented as a part of
+this software may be covered under patents or patent applications.
+
+THIS SOFTWARE IS PROVIDED BY THE AUTHOR ''AS IS'' AND ANY EXPRESS OR IMPLIED
+WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR
+CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
+ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/leres/pix2pix/__init__.py b/src/custom_controlnet_aux/leres/pix2pix/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/custom_controlnet_aux/leres/pix2pix/models/__init__.py b/src/custom_controlnet_aux/leres/pix2pix/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb83d333a655103146dde012f119f05c05635c3e
--- /dev/null
+++ b/src/custom_controlnet_aux/leres/pix2pix/models/__init__.py
@@ -0,0 +1,67 @@
+"""This package contains modules related to objective functions, optimizations, and network architectures.
+
+To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
+You need to implement the following five functions:
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+ -- : unpack data from dataset and apply preprocessing.
+ -- : produce intermediate results.
+ -- : calculate loss, gradients, and update network weights.
+ -- : (optionally) add model-specific options and set default options.
+
+In the function <__init__>, you need to define four lists:
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
+ -- self.model_names (str list): define networks used in our training.
+ -- self.visual_names (str list): specify the images that you want to display and save.
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.
+
+Now you can use the model class by specifying flag '--model dummy'.
+See our template model class 'template_model.py' for more details.
+"""
+
+import importlib
+from .base_model import BaseModel
+
+
+def find_model_using_name(model_name):
+ """Import the module "models/[model_name]_model.py".
+
+ In the file, the class called DatasetNameModel() will
+ be instantiated. It has to be a subclass of BaseModel,
+ and it is case-insensitive.
+ """
+ model_filename = "custom_controlnet_aux.leres.pix2pix.models." + model_name + "_model"
+ modellib = importlib.import_module(model_filename)
+ model = None
+ target_model_name = model_name.replace('_', '') + 'model'
+ for name, cls in modellib.__dict__.items():
+ if name.lower() == target_model_name.lower() \
+ and issubclass(cls, BaseModel):
+ model = cls
+
+ if model is None:
+ print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
+ exit(0)
+
+ return model
+
+
+def get_option_setter(model_name):
+ """Return the static method of the model class."""
+ model_class = find_model_using_name(model_name)
+ return model_class.modify_commandline_options
+
+
+def create_model(opt):
+ """Create a model given the option.
+
+ This function warps the class CustomDatasetDataLoader.
+ This is the main interface between this package and 'train.py'/'test.py'
+
+ Example:
+ >>> from models import create_model
+ >>> model = create_model(opt)
+ """
+ model = find_model_using_name(opt.model)
+ instance = model(opt)
+ print("model [%s] was created" % type(instance).__name__)
+ return instance
diff --git a/src/custom_controlnet_aux/leres/pix2pix/models/base_model.py b/src/custom_controlnet_aux/leres/pix2pix/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..66ec298f77cf769e39da38d1107e0b6dc38d519d
--- /dev/null
+++ b/src/custom_controlnet_aux/leres/pix2pix/models/base_model.py
@@ -0,0 +1,244 @@
+import gc
+import os
+from abc import ABC, abstractmethod
+from collections import OrderedDict
+
+import torch
+
+from ....util import torch_gc
+from . import networks
+
+
+class BaseModel(ABC):
+ """This class is an abstract base class (ABC) for models.
+ To create a subclass, you need to implement the following five functions:
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+ -- : unpack data from dataset and apply preprocessing.
+ -- : produce intermediate results.
+ -- : calculate losses, gradients, and update network weights.
+ -- : (optionally) add model-specific options and set default options.
+ """
+
+ def __init__(self, opt):
+ """Initialize the BaseModel class.
+
+ Parameters:
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+ When creating your custom class, you need to implement your own initialization.
+ In this function, you should first call
+ Then, you need to define four lists:
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
+ -- self.model_names (str list): define networks used in our training.
+ -- self.visual_names (str list): specify the images that you want to display and save.
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+ """
+ self.opt = opt
+ self.gpu_ids = opt.gpu_ids
+ self.isTrain = opt.isTrain
+ self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
+ self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
+ if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
+ torch.backends.cudnn.benchmark = True
+ self.loss_names = []
+ self.model_names = []
+ self.visual_names = []
+ self.optimizers = []
+ self.image_paths = []
+ self.metric = 0 # used for learning rate policy 'plateau'
+
+ @staticmethod
+ def modify_commandline_options(parser, is_train):
+ """Add new model-specific options, and rewrite default values for existing options.
+
+ Parameters:
+ parser -- original option parser
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+ """
+ return parser
+
+ @abstractmethod
+ def set_input(self, input):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+ Parameters:
+ input (dict): includes the data itself and its metadata information.
+ """
+ pass
+
+ @abstractmethod
+ def forward(self):
+ """Run forward pass; called by both functions and ."""
+ pass
+
+ @abstractmethod
+ def optimize_parameters(self):
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
+ pass
+
+ def setup(self, opt):
+ """Load and print networks; create schedulers
+
+ Parameters:
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ """
+ if self.isTrain:
+ self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
+ if not self.isTrain or opt.continue_train:
+ load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
+ self.load_networks(load_suffix)
+ self.print_networks(opt.verbose)
+
+ def eval(self):
+ """Make models eval mode during test time"""
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, 'net' + name)
+ net.eval()
+
+ def test(self):
+ """Forward function used in test time.
+
+ This function wraps function in no_grad() so we don't save intermediate steps for backprop
+ It also calls to produce additional visualization results
+ """
+ with torch.no_grad():
+ self.forward()
+ self.compute_visuals()
+
+ def compute_visuals(self):
+ """Calculate additional output images for visdom and HTML visualization"""
+ pass
+
+ def get_image_paths(self):
+ """ Return image paths that are used to load current data"""
+ return self.image_paths
+
+ def update_learning_rate(self):
+ """Update learning rates for all the networks; called at the end of every epoch"""
+ old_lr = self.optimizers[0].param_groups[0]['lr']
+ for scheduler in self.schedulers:
+ if self.opt.lr_policy == 'plateau':
+ scheduler.step(self.metric)
+ else:
+ scheduler.step()
+
+ lr = self.optimizers[0].param_groups[0]['lr']
+ print('learning rate %.7f -> %.7f' % (old_lr, lr))
+
+ def get_current_visuals(self):
+ """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
+ visual_ret = OrderedDict()
+ for name in self.visual_names:
+ if isinstance(name, str):
+ visual_ret[name] = getattr(self, name)
+ return visual_ret
+
+ def get_current_losses(self):
+ """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
+ errors_ret = OrderedDict()
+ for name in self.loss_names:
+ if isinstance(name, str):
+ errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
+ return errors_ret
+
+ def save_networks(self, epoch):
+ """Save all the networks to the disk.
+
+ Parameters:
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+ """
+ for name in self.model_names:
+ if isinstance(name, str):
+ save_filename = '%s_net_%s.pth' % (epoch, name)
+ save_path = os.path.join(self.save_dir, save_filename)
+ net = getattr(self, 'net' + name)
+
+ if len(self.gpu_ids) > 0 and torch.cuda.is_available():
+ torch.save(net.module.cpu().state_dict(), save_path)
+ net.cuda(self.gpu_ids[0])
+ else:
+ torch.save(net.cpu().state_dict(), save_path)
+
+ def unload_network(self, name):
+ """Unload network and gc.
+ """
+ if isinstance(name, str):
+ net = getattr(self, 'net' + name)
+ del net
+ gc.collect()
+ torch_gc()
+ return None
+
+ def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
+ """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
+ key = keys[i]
+ if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
+ if module.__class__.__name__.startswith('InstanceNorm') and \
+ (key == 'running_mean' or key == 'running_var'):
+ if getattr(module, key) is None:
+ state_dict.pop('.'.join(keys))
+ if module.__class__.__name__.startswith('InstanceNorm') and \
+ (key == 'num_batches_tracked'):
+ state_dict.pop('.'.join(keys))
+ else:
+ self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
+
+ def load_networks(self, epoch):
+ """Load all the networks from the disk.
+
+ Parameters:
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+ """
+ for name in self.model_names:
+ if isinstance(name, str):
+ load_filename = '%s_net_%s.pth' % (epoch, name)
+ load_path = os.path.join(self.save_dir, load_filename)
+ net = getattr(self, 'net' + name)
+ if isinstance(net, torch.nn.DataParallel):
+ net = net.module
+ # print('Loading depth boost model from %s' % load_path)
+ # if you are using PyTorch newer than 0.4 (e.g., built from
+ # GitHub source), you can remove str() on self.device
+ state_dict = torch.load(load_path, map_location=str(self.device))
+ if hasattr(state_dict, '_metadata'):
+ del state_dict._metadata
+
+ # patch InstanceNorm checkpoints prior to 0.4
+ for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
+ self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
+ net.load_state_dict(state_dict)
+
+ def print_networks(self, verbose):
+ """Print the total number of parameters in the network and (if verbose) network architecture
+
+ Parameters:
+ verbose (bool) -- if verbose: print the network architecture
+ """
+ print('---------- Networks initialized -------------')
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, 'net' + name)
+ num_params = 0
+ for param in net.parameters():
+ num_params += param.numel()
+ if verbose:
+ print(net)
+ print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
+ print('-----------------------------------------------')
+
+ def set_requires_grad(self, nets, requires_grad=False):
+ """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
+ Parameters:
+ nets (network list) -- a list of networks
+ requires_grad (bool) -- whether the networks require gradients or not
+ """
+ if not isinstance(nets, list):
+ nets = [nets]
+ for net in nets:
+ if net is not None:
+ for param in net.parameters():
+ param.requires_grad = requires_grad
diff --git a/src/custom_controlnet_aux/leres/pix2pix/models/base_model_hg.py b/src/custom_controlnet_aux/leres/pix2pix/models/base_model_hg.py
new file mode 100644
index 0000000000000000000000000000000000000000..1709accdf0b048b3793dfd1f58d1b06c35f7b907
--- /dev/null
+++ b/src/custom_controlnet_aux/leres/pix2pix/models/base_model_hg.py
@@ -0,0 +1,58 @@
+import os
+import torch
+
+class BaseModelHG():
+ def name(self):
+ return 'BaseModel'
+
+ def initialize(self, opt):
+ self.opt = opt
+ self.gpu_ids = opt.gpu_ids
+ self.isTrain = opt.isTrain
+ self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
+ self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
+
+ def set_input(self, input):
+ self.input = input
+
+ def forward(self):
+ pass
+
+ # used in test time, no backprop
+ def test(self):
+ pass
+
+ def get_image_paths(self):
+ pass
+
+ def optimize_parameters(self):
+ pass
+
+ def get_current_visuals(self):
+ return self.input
+
+ def get_current_errors(self):
+ return {}
+
+ def save(self, label):
+ pass
+
+ # helper saving function that can be used by subclasses
+ def save_network(self, network, network_label, epoch_label, gpu_ids):
+ save_filename = '_%s_net_%s.pth' % (epoch_label, network_label)
+ save_path = os.path.join(self.save_dir, save_filename)
+ torch.save(network.cpu().state_dict(), save_path)
+ if len(gpu_ids) and torch.cuda.is_available():
+ network.cuda(device_id=gpu_ids[0])
+
+ # helper loading function that can be used by subclasses
+ def load_network(self, network, network_label, epoch_label):
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
+ save_path = os.path.join(self.save_dir, save_filename)
+ print(save_path)
+ model = torch.load(save_path)
+ return model
+ # network.load_state_dict(torch.load(save_path))
+
+ def update_learning_rate():
+ pass
diff --git a/src/custom_controlnet_aux/leres/pix2pix/models/networks.py b/src/custom_controlnet_aux/leres/pix2pix/models/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cf912b2973721a02deefd042af621e732bad59f
--- /dev/null
+++ b/src/custom_controlnet_aux/leres/pix2pix/models/networks.py
@@ -0,0 +1,623 @@
+import torch
+import torch.nn as nn
+from torch.nn import init
+import functools
+from torch.optim import lr_scheduler
+
+
+###############################################################################
+# Helper Functions
+###############################################################################
+
+
+class Identity(nn.Module):
+ def forward(self, x):
+ return x
+
+
+def get_norm_layer(norm_type='instance'):
+ """Return a normalization layer
+
+ Parameters:
+ norm_type (str) -- the name of the normalization layer: batch | instance | none
+
+ For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
+ For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
+ """
+ if norm_type == 'batch':
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
+ elif norm_type == 'instance':
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
+ elif norm_type == 'none':
+ def norm_layer(x): return Identity()
+ else:
+ raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
+ return norm_layer
+
+
+def get_scheduler(optimizer, opt):
+ """Return a learning rate scheduler
+
+ Parameters:
+ optimizer -- the optimizer of the network
+ opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
+ opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
+
+ For 'linear', we keep the same learning rate for the first epochs
+ and linearly decay the rate to zero over the next epochs.
+ For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
+ See https://pytorch.org/docs/stable/optim.html for more details.
+ """
+ if opt.lr_policy == 'linear':
+ def lambda_rule(epoch):
+ lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
+ return lr_l
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
+ elif opt.lr_policy == 'step':
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
+ elif opt.lr_policy == 'plateau':
+ scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
+ elif opt.lr_policy == 'cosine':
+ scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
+ else:
+ return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
+ return scheduler
+
+
+def init_weights(net, init_type='normal', init_gain=0.02):
+ """Initialize network weights.
+
+ Parameters:
+ net (network) -- network to be initialized
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
+
+ We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
+ work better for some applications. Feel free to try yourself.
+ """
+ def init_func(m): # define the initialization function
+ classname = m.__class__.__name__
+ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
+ if init_type == 'normal':
+ init.normal_(m.weight.data, 0.0, init_gain)
+ elif init_type == 'xavier':
+ init.xavier_normal_(m.weight.data, gain=init_gain)
+ elif init_type == 'kaiming':
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+ elif init_type == 'orthogonal':
+ init.orthogonal_(m.weight.data, gain=init_gain)
+ else:
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
+ if hasattr(m, 'bias') and m.bias is not None:
+ init.constant_(m.bias.data, 0.0)
+ elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
+ init.normal_(m.weight.data, 1.0, init_gain)
+ init.constant_(m.bias.data, 0.0)
+
+ # print('initialize network with %s' % init_type)
+ net.apply(init_func) # apply the initialization function
+
+
+def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
+ """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
+ Parameters:
+ net (network) -- the network to be initialized
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
+ gain (float) -- scaling factor for normal, xavier and orthogonal.
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
+
+ Return an initialized network.
+ """
+ if len(gpu_ids) > 0:
+ assert(torch.cuda.is_available())
+ net.to(gpu_ids[0])
+ net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
+ init_weights(net, init_type, init_gain=init_gain)
+ return net
+
+
+def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
+ """Create a generator
+
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ output_nc (int) -- the number of channels in output images
+ ngf (int) -- the number of filters in the last conv layer
+ netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
+ norm (str) -- the name of normalization layers used in the network: batch | instance | none
+ use_dropout (bool) -- if use dropout layers.
+ init_type (str) -- the name of our initialization method.
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
+
+ Returns a generator
+
+ Our current implementation provides two types of generators:
+ U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
+ The original U-Net paper: https://arxiv.org/abs/1505.04597
+
+ Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
+ Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
+ We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
+
+
+ The generator has been initialized by . It uses RELU for non-linearity.
+ """
+ net = None
+ norm_layer = get_norm_layer(norm_type=norm)
+
+ if netG == 'resnet_9blocks':
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
+ elif netG == 'resnet_6blocks':
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
+ elif netG == 'resnet_12blocks':
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=12)
+ elif netG == 'unet_128':
+ net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
+ elif netG == 'unet_256':
+ net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
+ elif netG == 'unet_672':
+ net = UnetGenerator(input_nc, output_nc, 5, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
+ elif netG == 'unet_960':
+ net = UnetGenerator(input_nc, output_nc, 6, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
+ elif netG == 'unet_1024':
+ net = UnetGenerator(input_nc, output_nc, 10, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
+ else:
+ raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
+ return init_net(net, init_type, init_gain, gpu_ids)
+
+
+def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
+ """Create a discriminator
+
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ ndf (int) -- the number of filters in the first conv layer
+ netD (str) -- the architecture's name: basic | n_layers | pixel
+ n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
+ norm (str) -- the type of normalization layers used in the network.
+ init_type (str) -- the name of the initialization method.
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
+
+ Returns a discriminator
+
+ Our current implementation provides three types of discriminators:
+ [basic]: 'PatchGAN' classifier described in the original pix2pix paper.
+ It can classify whether 70×70 overlapping patches are real or fake.
+ Such a patch-level discriminator architecture has fewer parameters
+ than a full-image discriminator and can work on arbitrarily-sized images
+ in a fully convolutional fashion.
+
+ [n_layers]: With this mode, you can specify the number of conv layers in the discriminator
+ with the parameter (default=3 as used in [basic] (PatchGAN).)
+
+ [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
+ It encourages greater color diversity but has no effect on spatial statistics.
+
+ The discriminator has been initialized by . It uses Leakly RELU for non-linearity.
+ """
+ net = None
+ norm_layer = get_norm_layer(norm_type=norm)
+
+ if netD == 'basic': # default PatchGAN classifier
+ net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
+ elif netD == 'n_layers': # more options
+ net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
+ elif netD == 'pixel': # classify if each pixel is real or fake
+ net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
+ else:
+ raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
+ return init_net(net, init_type, init_gain, gpu_ids)
+
+
+##############################################################################
+# Classes
+##############################################################################
+class GANLoss(nn.Module):
+ """Define different GAN objectives.
+
+ The GANLoss class abstracts away the need to create the target label tensor
+ that has the same size as the input.
+ """
+
+ def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
+ """ Initialize the GANLoss class.
+
+ Parameters:
+ gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
+ target_real_label (bool) - - label for a real image
+ target_fake_label (bool) - - label of a fake image
+
+ Note: Do not use sigmoid as the last layer of Discriminator.
+ LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
+ """
+ super(GANLoss, self).__init__()
+ self.register_buffer('real_label', torch.tensor(target_real_label))
+ self.register_buffer('fake_label', torch.tensor(target_fake_label))
+ self.gan_mode = gan_mode
+ if gan_mode == 'lsgan':
+ self.loss = nn.MSELoss()
+ elif gan_mode == 'vanilla':
+ self.loss = nn.BCEWithLogitsLoss()
+ elif gan_mode in ['wgangp']:
+ self.loss = None
+ else:
+ raise NotImplementedError('gan mode %s not implemented' % gan_mode)
+
+ def get_target_tensor(self, prediction, target_is_real):
+ """Create label tensors with the same size as the input.
+
+ Parameters:
+ prediction (tensor) - - tpyically the prediction from a discriminator
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
+
+ Returns:
+ A label tensor filled with ground truth label, and with the size of the input
+ """
+
+ if target_is_real:
+ target_tensor = self.real_label
+ else:
+ target_tensor = self.fake_label
+ return target_tensor.expand_as(prediction)
+
+ def __call__(self, prediction, target_is_real):
+ """Calculate loss given Discriminator's output and grount truth labels.
+
+ Parameters:
+ prediction (tensor) - - tpyically the prediction output from a discriminator
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
+
+ Returns:
+ the calculated loss.
+ """
+ if self.gan_mode in ['lsgan', 'vanilla']:
+ target_tensor = self.get_target_tensor(prediction, target_is_real)
+ loss = self.loss(prediction, target_tensor)
+ elif self.gan_mode == 'wgangp':
+ if target_is_real:
+ loss = -prediction.mean()
+ else:
+ loss = prediction.mean()
+ return loss
+
+
+def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
+ """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
+
+ Arguments:
+ netD (network) -- discriminator network
+ real_data (tensor array) -- real images
+ fake_data (tensor array) -- generated images from the generator
+ device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
+ type (str) -- if we mix real and fake data or not [real | fake | mixed].
+ constant (float) -- the constant used in formula ( ||gradient||_2 - constant)^2
+ lambda_gp (float) -- weight for this loss
+
+ Returns the gradient penalty loss
+ """
+ if lambda_gp > 0.0:
+ if type == 'real': # either use real images, fake images, or a linear interpolation of two.
+ interpolatesv = real_data
+ elif type == 'fake':
+ interpolatesv = fake_data
+ elif type == 'mixed':
+ alpha = torch.rand(real_data.shape[0], 1, device=device)
+ alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
+ interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
+ else:
+ raise NotImplementedError('{} not implemented'.format(type))
+ interpolatesv.requires_grad_(True)
+ disc_interpolates = netD(interpolatesv)
+ gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
+ grad_outputs=torch.ones(disc_interpolates.size()).to(device),
+ create_graph=True, retain_graph=True, only_inputs=True)
+ gradients = gradients[0].view(real_data.size(0), -1) # flat the data
+ gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
+ return gradient_penalty, gradients
+ else:
+ return 0.0, None
+
+
+class ResnetGenerator(nn.Module):
+ """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
+
+ We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
+ """
+
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
+ """Construct a Resnet-based generator
+
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ output_nc (int) -- the number of channels in output images
+ ngf (int) -- the number of filters in the last conv layer
+ norm_layer -- normalization layer
+ use_dropout (bool) -- if use dropout layers
+ n_blocks (int) -- the number of ResNet blocks
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
+ """
+ assert(n_blocks >= 0)
+ super(ResnetGenerator, self).__init__()
+ if type(norm_layer) == functools.partial:
+ use_bias = norm_layer.func == nn.InstanceNorm2d
+ else:
+ use_bias = norm_layer == nn.InstanceNorm2d
+
+ model = [nn.ReflectionPad2d(3),
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
+ norm_layer(ngf),
+ nn.ReLU(True)]
+
+ n_downsampling = 2
+ for i in range(n_downsampling): # add downsampling layers
+ mult = 2 ** i
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
+ norm_layer(ngf * mult * 2),
+ nn.ReLU(True)]
+
+ mult = 2 ** n_downsampling
+ for i in range(n_blocks): # add ResNet blocks
+
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
+
+ for i in range(n_downsampling): # add upsampling layers
+ mult = 2 ** (n_downsampling - i)
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
+ kernel_size=3, stride=2,
+ padding=1, output_padding=1,
+ bias=use_bias),
+ norm_layer(int(ngf * mult / 2)),
+ nn.ReLU(True)]
+ model += [nn.ReflectionPad2d(3)]
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
+ model += [nn.Tanh()]
+
+ self.model = nn.Sequential(*model)
+
+ def forward(self, input):
+ """Standard forward"""
+ return self.model(input)
+
+
+class ResnetBlock(nn.Module):
+ """Define a Resnet block"""
+
+ def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
+ """Initialize the Resnet block
+
+ A resnet block is a conv block with skip connections
+ We construct a conv block with build_conv_block function,
+ and implement skip connections in function.
+ Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
+ """
+ super(ResnetBlock, self).__init__()
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
+
+ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
+ """Construct a convolutional block.
+
+ Parameters:
+ dim (int) -- the number of channels in the conv layer.
+ padding_type (str) -- the name of padding layer: reflect | replicate | zero
+ norm_layer -- normalization layer
+ use_dropout (bool) -- if use dropout layers.
+ use_bias (bool) -- if the conv layer uses bias or not
+
+ Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
+ """
+ conv_block = []
+ p = 0
+ if padding_type == 'reflect':
+ conv_block += [nn.ReflectionPad2d(1)]
+ elif padding_type == 'replicate':
+ conv_block += [nn.ReplicationPad2d(1)]
+ elif padding_type == 'zero':
+ p = 1
+ else:
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
+
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
+ if use_dropout:
+ conv_block += [nn.Dropout(0.5)]
+
+ p = 0
+ if padding_type == 'reflect':
+ conv_block += [nn.ReflectionPad2d(1)]
+ elif padding_type == 'replicate':
+ conv_block += [nn.ReplicationPad2d(1)]
+ elif padding_type == 'zero':
+ p = 1
+ else:
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
+
+ return nn.Sequential(*conv_block)
+
+ def forward(self, x):
+ """Forward function (with skip connections)"""
+ out = x + self.conv_block(x) # add skip connections
+ return out
+
+
+class UnetGenerator(nn.Module):
+ """Create a Unet-based generator"""
+
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
+ """Construct a Unet generator
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ output_nc (int) -- the number of channels in output images
+ num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
+ image of size 128x128 will become of size 1x1 # at the bottleneck
+ ngf (int) -- the number of filters in the last conv layer
+ norm_layer -- normalization layer
+
+ We construct the U-Net from the innermost layer to the outermost layer.
+ It is a recursive process.
+ """
+ super(UnetGenerator, self).__init__()
+ # construct unet structure
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
+ for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
+ # gradually reduce the number of filters from ngf * 8 to ngf
+ unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+ self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer
+
+ def forward(self, input):
+ """Standard forward"""
+ return self.model(input)
+
+
+class UnetSkipConnectionBlock(nn.Module):
+ """Defines the Unet submodule with skip connection.
+ X -------------------identity----------------------
+ |-- downsampling -- |submodule| -- upsampling --|
+ """
+
+ def __init__(self, outer_nc, inner_nc, input_nc=None,
+ submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
+ """Construct a Unet submodule with skip connections.
+
+ Parameters:
+ outer_nc (int) -- the number of filters in the outer conv layer
+ inner_nc (int) -- the number of filters in the inner conv layer
+ input_nc (int) -- the number of channels in input images/features
+ submodule (UnetSkipConnectionBlock) -- previously defined submodules
+ outermost (bool) -- if this module is the outermost module
+ innermost (bool) -- if this module is the innermost module
+ norm_layer -- normalization layer
+ use_dropout (bool) -- if use dropout layers.
+ """
+ super(UnetSkipConnectionBlock, self).__init__()
+ self.outermost = outermost
+ if type(norm_layer) == functools.partial:
+ use_bias = norm_layer.func == nn.InstanceNorm2d
+ else:
+ use_bias = norm_layer == nn.InstanceNorm2d
+ if input_nc is None:
+ input_nc = outer_nc
+ downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
+ stride=2, padding=1, bias=use_bias)
+ downrelu = nn.LeakyReLU(0.2, True)
+ downnorm = norm_layer(inner_nc)
+ uprelu = nn.ReLU(True)
+ upnorm = norm_layer(outer_nc)
+
+ if outermost:
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+ kernel_size=4, stride=2,
+ padding=1)
+ down = [downconv]
+ up = [uprelu, upconv, nn.Tanh()]
+ model = down + [submodule] + up
+ elif innermost:
+ upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
+ kernel_size=4, stride=2,
+ padding=1, bias=use_bias)
+ down = [downrelu, downconv]
+ up = [uprelu, upconv, upnorm]
+ model = down + up
+ else:
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+ kernel_size=4, stride=2,
+ padding=1, bias=use_bias)
+ down = [downrelu, downconv, downnorm]
+ up = [uprelu, upconv, upnorm]
+
+ if use_dropout:
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
+ else:
+ model = down + [submodule] + up
+
+ self.model = nn.Sequential(*model)
+
+ def forward(self, x):
+ if self.outermost:
+ return self.model(x)
+ else: # add skip connections
+ return torch.cat([x, self.model(x)], 1)
+
+
+class NLayerDiscriminator(nn.Module):
+ """Defines a PatchGAN discriminator"""
+
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
+ """Construct a PatchGAN discriminator
+
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ ndf (int) -- the number of filters in the last conv layer
+ n_layers (int) -- the number of conv layers in the discriminator
+ norm_layer -- normalization layer
+ """
+ super(NLayerDiscriminator, self).__init__()
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
+ use_bias = norm_layer.func == nn.InstanceNorm2d
+ else:
+ use_bias = norm_layer == nn.InstanceNorm2d
+
+ kw = 4
+ padw = 1
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
+ nf_mult = 1
+ nf_mult_prev = 1
+ for n in range(1, n_layers): # gradually increase the number of filters
+ nf_mult_prev = nf_mult
+ nf_mult = min(2 ** n, 8)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ nf_mult_prev = nf_mult
+ nf_mult = min(2 ** n_layers, 8)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
+ self.model = nn.Sequential(*sequence)
+
+ def forward(self, input):
+ """Standard forward."""
+ return self.model(input)
+
+
+class PixelDiscriminator(nn.Module):
+ """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
+
+ def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
+ """Construct a 1x1 PatchGAN discriminator
+
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ ndf (int) -- the number of filters in the last conv layer
+ norm_layer -- normalization layer
+ """
+ super(PixelDiscriminator, self).__init__()
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
+ use_bias = norm_layer.func == nn.InstanceNorm2d
+ else:
+ use_bias = norm_layer == nn.InstanceNorm2d
+
+ self.net = [
+ nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
+ norm_layer(ndf * 2),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
+
+ self.net = nn.Sequential(*self.net)
+
+ def forward(self, input):
+ """Standard forward."""
+ return self.net(input)
diff --git a/src/custom_controlnet_aux/leres/pix2pix/models/pix2pix4depth_model.py b/src/custom_controlnet_aux/leres/pix2pix/models/pix2pix4depth_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..89e89652feb96314973a050c5a2477b474630abb
--- /dev/null
+++ b/src/custom_controlnet_aux/leres/pix2pix/models/pix2pix4depth_model.py
@@ -0,0 +1,155 @@
+import torch
+from .base_model import BaseModel
+from . import networks
+
+
+class Pix2Pix4DepthModel(BaseModel):
+ """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.
+
+ The model training requires '--dataset_mode aligned' dataset.
+ By default, it uses a '--netG unet256' U-Net generator,
+ a '--netD basic' discriminator (PatchGAN),
+ and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper).
+
+ pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
+ """
+ @staticmethod
+ def modify_commandline_options(parser, is_train=True):
+ """Add new dataset-specific options, and rewrite default values for existing options.
+
+ Parameters:
+ parser -- original option parser
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+
+ For pix2pix, we do not use image buffer
+ The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1
+ By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets.
+ """
+ # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/)
+ parser.set_defaults(input_nc=2,output_nc=1,norm='none', netG='unet_1024', dataset_mode='depthmerge')
+ if is_train:
+ parser.set_defaults(pool_size=0, gan_mode='vanilla',)
+ parser.add_argument('--lambda_L1', type=float, default=1000, help='weight for L1 loss')
+ return parser
+
+ def __init__(self, opt):
+ """Initialize the pix2pix class.
+
+ Parameters:
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+ """
+ BaseModel.__init__(self, opt)
+ # specify the training losses you want to print out. The training/test scripts will call
+
+ self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
+ # self.loss_names = ['G_L1']
+
+ # specify the images you want to save/display. The training/test scripts will call
+ if self.isTrain:
+ self.visual_names = ['outer','inner', 'fake_B', 'real_B']
+ else:
+ self.visual_names = ['fake_B']
+
+ # specify the models you want to save to the disk. The training/test scripts will call and
+ if self.isTrain:
+ self.model_names = ['G','D']
+ else: # during test time, only load G
+ self.model_names = ['G']
+
+ # define networks (both generator and discriminator)
+ self.netG = networks.define_G(opt.input_nc, opt.output_nc, 64, 'unet_1024', 'none',
+ False, 'normal', 0.02, self.gpu_ids)
+
+ if self.isTrain: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
+ self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
+ opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
+
+ if self.isTrain:
+ # define loss functions
+ self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
+ self.criterionL1 = torch.nn.L1Loss()
+ # initialize optimizers; schedulers will be automatically created by function .
+ self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=1e-4, betas=(opt.beta1, 0.999))
+ self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=2e-06, betas=(opt.beta1, 0.999))
+ self.optimizers.append(self.optimizer_G)
+ self.optimizers.append(self.optimizer_D)
+
+ def set_input_train(self, input):
+ self.outer = input['data_outer'].to(self.device)
+ self.outer = torch.nn.functional.interpolate(self.outer,(1024,1024),mode='bilinear',align_corners=False)
+
+ self.inner = input['data_inner'].to(self.device)
+ self.inner = torch.nn.functional.interpolate(self.inner,(1024,1024),mode='bilinear',align_corners=False)
+
+ self.image_paths = input['image_path']
+
+ if self.isTrain:
+ self.gtfake = input['data_gtfake'].to(self.device)
+ self.gtfake = torch.nn.functional.interpolate(self.gtfake, (1024, 1024), mode='bilinear', align_corners=False)
+ self.real_B = self.gtfake
+
+ self.real_A = torch.cat((self.outer, self.inner), 1)
+
+ def set_input(self, outer, inner):
+ inner = torch.from_numpy(inner).unsqueeze(0).unsqueeze(0)
+ outer = torch.from_numpy(outer).unsqueeze(0).unsqueeze(0)
+
+ inner = (inner - torch.min(inner))/(torch.max(inner)-torch.min(inner))
+ outer = (outer - torch.min(outer))/(torch.max(outer)-torch.min(outer))
+
+ inner = self.normalize(inner)
+ outer = self.normalize(outer)
+
+ self.real_A = torch.cat((outer, inner), 1).to(self.device)
+
+
+ def normalize(self, input):
+ input = input * 2
+ input = input - 1
+ return input
+
+ def forward(self):
+ """Run forward pass; called by both functions and ."""
+ self.fake_B = self.netG(self.real_A) # G(A)
+
+ def backward_D(self):
+ """Calculate GAN loss for the discriminator"""
+ # Fake; stop backprop to the generator by detaching fake_B
+ fake_AB = torch.cat((self.real_A, self.fake_B), 1) # we use conditional GANs; we need to feed both input and output to the discriminator
+ pred_fake = self.netD(fake_AB.detach())
+ self.loss_D_fake = self.criterionGAN(pred_fake, False)
+ # Real
+ real_AB = torch.cat((self.real_A, self.real_B), 1)
+ pred_real = self.netD(real_AB)
+ self.loss_D_real = self.criterionGAN(pred_real, True)
+ # combine loss and calculate gradients
+ self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
+ self.loss_D.backward()
+
+ def backward_G(self):
+ """Calculate GAN and L1 loss for the generator"""
+ # First, G(A) should fake the discriminator
+ fake_AB = torch.cat((self.real_A, self.fake_B), 1)
+ pred_fake = self.netD(fake_AB)
+ self.loss_G_GAN = self.criterionGAN(pred_fake, True)
+ # Second, G(A) = B
+ self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
+ # combine loss and calculate gradients
+ self.loss_G = self.loss_G_L1 + self.loss_G_GAN
+ self.loss_G.backward()
+
+ def optimize_parameters(self):
+ self.forward() # compute fake images: G(A)
+ # update D
+ self.set_requires_grad(self.netD, True) # enable backprop for D
+ self.optimizer_D.zero_grad() # set D's gradients to zero
+ self.backward_D() # calculate gradients for D
+ self.optimizer_D.step() # update D's weights
+ # update G
+ self.set_requires_grad(self.netD, False) # D requires no gradients when optimizing G
+ self.optimizer_G.zero_grad() # set G's gradients to zero
+ self.backward_G() # calculate graidents for G
+ self.optimizer_G.step() # udpate G's weights
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/leres/pix2pix/options/__init__.py b/src/custom_controlnet_aux/leres/pix2pix/options/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7eedebe54aa70169fd25951b3034d819e396c90
--- /dev/null
+++ b/src/custom_controlnet_aux/leres/pix2pix/options/__init__.py
@@ -0,0 +1 @@
+"""This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
diff --git a/src/custom_controlnet_aux/leres/pix2pix/options/base_options.py b/src/custom_controlnet_aux/leres/pix2pix/options/base_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..533a1e88a7e8494223f6994e6861c93667754f83
--- /dev/null
+++ b/src/custom_controlnet_aux/leres/pix2pix/options/base_options.py
@@ -0,0 +1,156 @@
+import argparse
+import os
+from ...pix2pix.util import util
+# import torch
+from ...pix2pix import models
+# import pix2pix.data
+import numpy as np
+
+class BaseOptions():
+ """This class defines options used during both training and test time.
+
+ It also implements several helper functions such as parsing, printing, and saving the options.
+ It also gathers additional options defined in functions in both dataset class and model class.
+ """
+
+ def __init__(self):
+ """Reset the class; indicates the class hasn't been initailized"""
+ self.initialized = False
+
+ def initialize(self, parser):
+ """Define the common options that are used in both training and test."""
+ # basic parameters
+ parser.add_argument('--dataroot', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
+ parser.add_argument('--name', type=str, default='void', help='mahdi_unet_new, scaled_unet')
+ parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
+ parser.add_argument('--checkpoints_dir', type=str, default='./pix2pix/checkpoints', help='models are saved here')
+ # model parameters
+ parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')
+ parser.add_argument('--input_nc', type=int, default=2, help='# of input image channels: 3 for RGB and 1 for grayscale')
+ parser.add_argument('--output_nc', type=int, default=1, help='# of output image channels: 3 for RGB and 1 for grayscale')
+ parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
+ parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
+ parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
+ parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')
+ parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
+ parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')
+ parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
+ parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
+ parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
+ # dataset parameters
+ parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
+ parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
+ parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
+ parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
+ parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
+ parser.add_argument('--load_size', type=int, default=672, help='scale images to this size')
+ parser.add_argument('--crop_size', type=int, default=672, help='then crop to this size')
+ parser.add_argument('--max_dataset_size', type=int, default=10000, help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
+ parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
+ parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
+ parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
+ # additional parameters
+ parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
+ parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
+ parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
+ parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
+
+ parser.add_argument('--data_dir', type=str, required=False,
+ help='input files directory images can be .png .jpg .tiff')
+ parser.add_argument('--output_dir', type=str, required=False,
+ help='result dir. result depth will be png. vides are JMPG as avi')
+ parser.add_argument('--savecrops', type=int, required=False)
+ parser.add_argument('--savewholeest', type=int, required=False)
+ parser.add_argument('--output_resolution', type=int, required=False,
+ help='0 for no restriction 1 for resize to input size')
+ parser.add_argument('--net_receptive_field_size', type=int, required=False)
+ parser.add_argument('--pix2pixsize', type=int, required=False)
+ parser.add_argument('--generatevideo', type=int, required=False)
+ parser.add_argument('--depthNet', type=int, required=False, help='0: midas 1:strurturedRL')
+ parser.add_argument('--R0', action='store_true')
+ parser.add_argument('--R20', action='store_true')
+ parser.add_argument('--Final', action='store_true')
+ parser.add_argument('--colorize_results', action='store_true')
+ parser.add_argument('--max_res', type=float, default=np.inf)
+
+ self.initialized = True
+ return parser
+
+ def gather_options(self):
+ """Initialize our parser with basic options(only once).
+ Add additional model-specific and dataset-specific options.
+ These options are defined in the function
+ in model and dataset classes.
+ """
+ if not self.initialized: # check if it has been initialized
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser = self.initialize(parser)
+
+ # get the basic options
+ opt, _ = parser.parse_known_args()
+
+ # modify model-related parser options
+ model_name = opt.model
+ model_option_setter = models.get_option_setter(model_name)
+ parser = model_option_setter(parser, self.isTrain)
+ opt, _ = parser.parse_known_args() # parse again with new defaults
+
+ # modify dataset-related parser options
+ # dataset_name = opt.dataset_mode
+ # dataset_option_setter = pix2pix.data.get_option_setter(dataset_name)
+ # parser = dataset_option_setter(parser, self.isTrain)
+
+ # save and return the parser
+ self.parser = parser
+ #return parser.parse_args() #EVIL
+ return opt
+
+ def print_options(self, opt):
+ """Print and save options
+
+ It will print both current options and default values(if different).
+ It will save options into a text file / [checkpoints_dir] / opt.txt
+ """
+ message = ''
+ message += '----------------- Options ---------------\n'
+ for k, v in sorted(vars(opt).items()):
+ comment = ''
+ default = self.parser.get_default(k)
+ if v != default:
+ comment = '\t[default: %s]' % str(default)
+ message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
+ message += '----------------- End -------------------'
+ print(message)
+
+ # save to the disk
+ expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
+ util.mkdirs(expr_dir)
+ file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
+ with open(file_name, 'wt') as opt_file:
+ opt_file.write(message)
+ opt_file.write('\n')
+
+ def parse(self):
+ """Parse our options, create checkpoints directory suffix, and set up gpu device."""
+ opt = self.gather_options()
+ opt.isTrain = self.isTrain # train or test
+
+ # process opt.suffix
+ if opt.suffix:
+ suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
+ opt.name = opt.name + suffix
+
+ #self.print_options(opt)
+
+ # set gpu ids
+ str_ids = opt.gpu_ids.split(',')
+ opt.gpu_ids = []
+ for str_id in str_ids:
+ id = int(str_id)
+ if id >= 0:
+ opt.gpu_ids.append(id)
+ #if len(opt.gpu_ids) > 0:
+ # torch.cuda.set_device(opt.gpu_ids[0])
+
+ self.opt = opt
+ return self.opt
diff --git a/src/custom_controlnet_aux/leres/pix2pix/options/test_options.py b/src/custom_controlnet_aux/leres/pix2pix/options/test_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3424b5e3b66d6813f74c8cecad691d7488d121c
--- /dev/null
+++ b/src/custom_controlnet_aux/leres/pix2pix/options/test_options.py
@@ -0,0 +1,22 @@
+from .base_options import BaseOptions
+
+
+class TestOptions(BaseOptions):
+ """This class includes test options.
+
+ It also includes shared options defined in BaseOptions.
+ """
+
+ def initialize(self, parser):
+ parser = BaseOptions.initialize(self, parser) # define shared options
+ parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
+ parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
+ # Dropout and Batchnorm has different behavioir during training and test.
+ parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
+ parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
+ # rewrite devalue values
+ parser.set_defaults(model='pix2pix4depth')
+ # To avoid cropping, the load_size should be the same as crop_size
+ parser.set_defaults(load_size=parser.get_default('crop_size'))
+ self.isTrain = False
+ return parser
diff --git a/src/custom_controlnet_aux/leres/pix2pix/util/__init__.py b/src/custom_controlnet_aux/leres/pix2pix/util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae36f63d8859ec0c60dcbfe67c4ac324e751ddf7
--- /dev/null
+++ b/src/custom_controlnet_aux/leres/pix2pix/util/__init__.py
@@ -0,0 +1 @@
+"""This package includes a miscellaneous collection of useful helper functions."""
diff --git a/src/custom_controlnet_aux/leres/pix2pix/util/util.py b/src/custom_controlnet_aux/leres/pix2pix/util/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a7aceaa00681cb76675df7866bf8db58c8d2caf
--- /dev/null
+++ b/src/custom_controlnet_aux/leres/pix2pix/util/util.py
@@ -0,0 +1,105 @@
+"""This module contains simple helper functions """
+from __future__ import print_function
+import torch
+import numpy as np
+from PIL import Image
+import os
+
+
+def tensor2im(input_image, imtype=np.uint16):
+ """"Converts a Tensor array into a numpy image array.
+
+ Parameters:
+ input_image (tensor) -- the input image tensor array
+ imtype (type) -- the desired type of the converted numpy array
+ """
+ if not isinstance(input_image, np.ndarray):
+ if isinstance(input_image, torch.Tensor): # get the data from a variable
+ image_tensor = input_image.data
+ else:
+ return input_image
+ image_numpy = torch.squeeze(image_tensor).cpu().numpy() # convert it into a numpy array
+ image_numpy = (image_numpy + 1) / 2.0 * (2**16-1) #
+ else: # if it is a numpy array, do nothing
+ image_numpy = input_image
+ return image_numpy.astype(imtype)
+
+
+def diagnose_network(net, name='network'):
+ """Calculate and print the mean of average absolute(gradients)
+
+ Parameters:
+ net (torch network) -- Torch network
+ name (str) -- the name of the network
+ """
+ mean = 0.0
+ count = 0
+ for param in net.parameters():
+ if param.grad is not None:
+ mean += torch.mean(torch.abs(param.grad.data))
+ count += 1
+ if count > 0:
+ mean = mean / count
+ print(name)
+ print(mean)
+
+
+def save_image(image_numpy, image_path, aspect_ratio=1.0):
+ """Save a numpy image to the disk
+
+ Parameters:
+ image_numpy (numpy array) -- input numpy array
+ image_path (str) -- the path of the image
+ """
+ image_pil = Image.fromarray(image_numpy)
+
+ image_pil = image_pil.convert('I;16')
+
+ # image_pil = Image.fromarray(image_numpy)
+ # h, w, _ = image_numpy.shape
+ #
+ # if aspect_ratio > 1.0:
+ # image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
+ # if aspect_ratio < 1.0:
+ # image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
+
+ image_pil.save(image_path)
+
+
+def print_numpy(x, val=True, shp=False):
+ """Print the mean, min, max, median, std, and size of a numpy array
+
+ Parameters:
+ val (bool) -- if print the values of the numpy array
+ shp (bool) -- if print the shape of the numpy array
+ """
+ x = x.astype(np.float64)
+ if shp:
+ print('shape,', x.shape)
+ if val:
+ x = x.flatten()
+ print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
+ np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
+
+
+def mkdirs(paths):
+ """create empty directories if they don't exist
+
+ Parameters:
+ paths (str list) -- a list of directory paths
+ """
+ if isinstance(paths, list) and not isinstance(paths, str):
+ for path in paths:
+ mkdir(path)
+ else:
+ mkdir(paths)
+
+
+def mkdir(path):
+ """create a single empty directory if it didn't exist
+
+ Parameters:
+ path (str) -- a single directory path
+ """
+ if not os.path.exists(path):
+ os.makedirs(path)
diff --git a/src/custom_controlnet_aux/lineart/LICENSE b/src/custom_controlnet_aux/lineart/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..16a9d56a3d4c15e4f34ac5426459c58487b01520
--- /dev/null
+++ b/src/custom_controlnet_aux/lineart/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Caroline Chan
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/lineart/__init__.py b/src/custom_controlnet_aux/lineart/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d42dc501f67f8fa333b783883373ddb271bfceaf
--- /dev/null
+++ b/src/custom_controlnet_aux/lineart/__init__.py
@@ -0,0 +1,141 @@
+import os
+import warnings
+
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange
+from PIL import Image
+
+from custom_controlnet_aux.util import HWC3, resize_image_with_pad, common_input_validate, custom_hf_download, HF_MODEL_NAME
+
+norm_layer = nn.InstanceNorm2d
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self, in_features):
+ super(ResidualBlock, self).__init__()
+
+ conv_block = [ nn.ReflectionPad2d(1),
+ nn.Conv2d(in_features, in_features, 3),
+ norm_layer(in_features),
+ nn.ReLU(inplace=True),
+ nn.ReflectionPad2d(1),
+ nn.Conv2d(in_features, in_features, 3),
+ norm_layer(in_features)
+ ]
+
+ self.conv_block = nn.Sequential(*conv_block)
+
+ def forward(self, x):
+ return x + self.conv_block(x)
+
+
+class Generator(nn.Module):
+ def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
+ super(Generator, self).__init__()
+
+ # Initial convolution block
+ model0 = [ nn.ReflectionPad2d(3),
+ nn.Conv2d(input_nc, 64, 7),
+ norm_layer(64),
+ nn.ReLU(inplace=True) ]
+ self.model0 = nn.Sequential(*model0)
+
+ # Downsampling
+ model1 = []
+ in_features = 64
+ out_features = in_features*2
+ for _ in range(2):
+ model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
+ norm_layer(out_features),
+ nn.ReLU(inplace=True) ]
+ in_features = out_features
+ out_features = in_features*2
+ self.model1 = nn.Sequential(*model1)
+
+ model2 = []
+ # Residual blocks
+ for _ in range(n_residual_blocks):
+ model2 += [ResidualBlock(in_features)]
+ self.model2 = nn.Sequential(*model2)
+
+ # Upsampling
+ model3 = []
+ out_features = in_features//2
+ for _ in range(2):
+ model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
+ norm_layer(out_features),
+ nn.ReLU(inplace=True) ]
+ in_features = out_features
+ out_features = in_features//2
+ self.model3 = nn.Sequential(*model3)
+
+ # Output layer
+ model4 = [ nn.ReflectionPad2d(3),
+ nn.Conv2d(64, output_nc, 7)]
+ if sigmoid:
+ model4 += [nn.Sigmoid()]
+
+ self.model4 = nn.Sequential(*model4)
+
+ def forward(self, x, cond=None):
+ out = self.model0(x)
+ out = self.model1(out)
+ out = self.model2(out)
+ out = self.model3(out)
+ out = self.model4(out)
+
+ return out
+
+
+class LineartDetector:
+ def __init__(self, model, coarse_model):
+ self.model = model
+ self.model_coarse = coarse_model
+ self.device = "cpu"
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, filename="sk_model.pth", coarse_filename="sk_model2.pth"):
+ model_path = custom_hf_download(pretrained_model_or_path, filename)
+ coarse_model_path = custom_hf_download(pretrained_model_or_path, coarse_filename)
+
+ model = Generator(3, 1, 3)
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
+ model.eval()
+
+ coarse_model = Generator(3, 1, 3)
+ coarse_model.load_state_dict(torch.load(coarse_model_path, map_location=torch.device('cpu')))
+ coarse_model.eval()
+
+ return cls(model, coarse_model)
+
+ def to(self, device):
+ self.model.to(device)
+ self.model_coarse.to(device)
+ self.device = device
+ return self
+
+ def __call__(self, input_image, coarse=False, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+
+ model = self.model_coarse if coarse else self.model
+ assert detected_map.ndim == 3
+ with torch.no_grad():
+ image = torch.from_numpy(detected_map).float().to(self.device)
+ image = image / 255.0
+ image = rearrange(image, 'h w c -> 1 c h w')
+ line = model(image)[0][0]
+
+ line = line.cpu().numpy()
+ line = (line * 255.0).clip(0, 255).astype(np.uint8)
+
+ detected_map = HWC3(line)
+ detected_map = remove_pad(255 - detected_map)
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
diff --git a/src/custom_controlnet_aux/lineart_anime/LICENSE b/src/custom_controlnet_aux/lineart_anime/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..16a9d56a3d4c15e4f34ac5426459c58487b01520
--- /dev/null
+++ b/src/custom_controlnet_aux/lineart_anime/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Caroline Chan
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/lineart_anime/__init__.py b/src/custom_controlnet_aux/lineart_anime/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f42dfeb3a5dd504121c4d22aa0bab7a16dc681ed
--- /dev/null
+++ b/src/custom_controlnet_aux/lineart_anime/__init__.py
@@ -0,0 +1,167 @@
+import functools
+import os
+import warnings
+
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange
+from huggingface_hub import hf_hub_download
+from PIL import Image
+
+from custom_controlnet_aux.util import HWC3, resize_image_with_pad, common_input_validate, custom_hf_download, HF_MODEL_NAME
+
+
+class UnetGenerator(nn.Module):
+ """Create a Unet-based generator"""
+
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
+ """Construct a Unet generator
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ output_nc (int) -- the number of channels in output images
+ num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
+ image of size 128x128 will become of size 1x1 # at the bottleneck
+ ngf (int) -- the number of filters in the last conv layer
+ norm_layer -- normalization layer
+ We construct the U-Net from the innermost layer to the outermost layer.
+ It is a recursive process.
+ """
+ super(UnetGenerator, self).__init__()
+ # construct unet structure
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
+ for _ in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
+ # gradually reduce the number of filters from ngf * 8 to ngf
+ unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+ self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer
+
+ def forward(self, input):
+ """Standard forward"""
+ return self.model(input)
+
+
+class UnetSkipConnectionBlock(nn.Module):
+ """Defines the Unet submodule with skip connection.
+ X -------------------identity----------------------
+ |-- downsampling -- |submodule| -- upsampling --|
+ """
+
+ def __init__(self, outer_nc, inner_nc, input_nc=None,
+ submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
+ """Construct a Unet submodule with skip connections.
+ Parameters:
+ outer_nc (int) -- the number of filters in the outer conv layer
+ inner_nc (int) -- the number of filters in the inner conv layer
+ input_nc (int) -- the number of channels in input images/features
+ submodule (UnetSkipConnectionBlock) -- previously defined submodules
+ outermost (bool) -- if this module is the outermost module
+ innermost (bool) -- if this module is the innermost module
+ norm_layer -- normalization layer
+ use_dropout (bool) -- if use dropout layers.
+ """
+ super(UnetSkipConnectionBlock, self).__init__()
+ self.outermost = outermost
+ if type(norm_layer) == functools.partial:
+ use_bias = norm_layer.func == nn.InstanceNorm2d
+ else:
+ use_bias = norm_layer == nn.InstanceNorm2d
+ if input_nc is None:
+ input_nc = outer_nc
+ downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
+ stride=2, padding=1, bias=use_bias)
+ downrelu = nn.LeakyReLU(0.2, True)
+ downnorm = norm_layer(inner_nc)
+ uprelu = nn.ReLU(True)
+ upnorm = norm_layer(outer_nc)
+
+ if outermost:
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+ kernel_size=4, stride=2,
+ padding=1)
+ down = [downconv]
+ up = [uprelu, upconv, nn.Tanh()]
+ model = down + [submodule] + up
+ elif innermost:
+ upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
+ kernel_size=4, stride=2,
+ padding=1, bias=use_bias)
+ down = [downrelu, downconv]
+ up = [uprelu, upconv, upnorm]
+ model = down + up
+ else:
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+ kernel_size=4, stride=2,
+ padding=1, bias=use_bias)
+ down = [downrelu, downconv, downnorm]
+ up = [uprelu, upconv, upnorm]
+
+ if use_dropout:
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
+ else:
+ model = down + [submodule] + up
+
+ self.model = nn.Sequential(*model)
+
+ def forward(self, x):
+ if self.outermost:
+ return self.model(x)
+ else: # add skip connections
+ return torch.cat([x, self.model(x)], 1)
+
+
+class LineartAnimeDetector:
+ def __init__(self, model):
+ self.model = model
+ self.device = "cpu"
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, filename="netG.pth"):
+ model_path = custom_hf_download(pretrained_model_or_path, filename)
+
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
+ net = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
+ ckpt = torch.load(model_path)
+ for key in list(ckpt.keys()):
+ if 'module.' in key:
+ ckpt[key.replace('module.', '')] = ckpt[key]
+ del ckpt[key]
+ net.load_state_dict(ckpt)
+ net.eval()
+
+ return cls(net)
+
+ def to(self, device):
+ self.model.to(device)
+ self.device = device
+ return self
+
+ def __call__(self, input_image, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+
+ H, W, C = input_image.shape
+ Hn = 256 * int(np.ceil(float(H) / 256.0))
+ Wn = 256 * int(np.ceil(float(W) / 256.0))
+ input_image = cv2.resize(input_image, (Wn, Hn), interpolation=cv2.INTER_CUBIC)
+
+ with torch.no_grad():
+ image_feed = torch.from_numpy(input_image).float().to(self.device)
+ image_feed = image_feed / 127.5 - 1.0
+ image_feed = rearrange(image_feed, 'h w c -> 1 c h w')
+
+ line = self.model(image_feed)[0, 0] * 127.5 + 127.5
+ line = line.cpu().numpy()
+ line = line.clip(0, 255).astype(np.uint8)
+
+ #A1111 uses INTER AREA for downscaling so ig that is the best choice
+ detected_map = cv2.resize(HWC3(line), (W, H), interpolation=cv2.INTER_AREA)
+ detected_map = remove_pad(255 - detected_map)
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
diff --git a/src/custom_controlnet_aux/lineart_standard/__init__.py b/src/custom_controlnet_aux/lineart_standard/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7932e733cf8147a7a500c86ed3eb192a23504626
--- /dev/null
+++ b/src/custom_controlnet_aux/lineart_standard/__init__.py
@@ -0,0 +1,21 @@
+import cv2
+import numpy as np
+from PIL import Image
+from custom_controlnet_aux.util import resize_image_with_pad, common_input_validate, HWC3
+
+class LineartStandardDetector:
+ def __call__(self, input_image=None, guassian_sigma=6.0, intensity_threshold=8, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+
+ x = input_image.astype(np.float32)
+ g = cv2.GaussianBlur(x, (0, 0), guassian_sigma)
+ intensity = np.min(g - x, axis=2).clip(0, 255)
+ intensity /= max(16, np.median(intensity[intensity > intensity_threshold]))
+ intensity *= 127
+ detected_map = intensity.clip(0, 255).astype(np.uint8)
+
+ detected_map = HWC3(remove_pad(detected_map))
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+ return detected_map
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/manga_line/LICENSE b/src/custom_controlnet_aux/manga_line/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..bdca75a54d05781782e3d939401e93161cdd88f7
--- /dev/null
+++ b/src/custom_controlnet_aux/manga_line/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 Miaomiao Li
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/manga_line/__init__.py b/src/custom_controlnet_aux/manga_line/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..16dea641c70110d920b2bcd72754bb94c4ff087f
--- /dev/null
+++ b/src/custom_controlnet_aux/manga_line/__init__.py
@@ -0,0 +1,63 @@
+# MangaLineExtraction_PyTorch
+# https://github.com/ljsabc/MangaLineExtraction_PyTorch
+
+#NOTE: This preprocessor is designed to work with lineart_anime ControlNet so the result will be white lines on black canvas
+
+import torch
+import numpy as np
+import os
+import cv2
+from einops import rearrange
+from .model_torch import res_skip
+from PIL import Image
+import warnings
+
+from custom_controlnet_aux.util import HWC3, resize_image_with_pad, common_input_validate, custom_hf_download, HF_MODEL_NAME
+
+class LineartMangaDetector:
+ def __init__(self, model):
+ self.model = model
+ self.device = "cpu"
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, filename="erika.pth"):
+ model_path = custom_hf_download(pretrained_model_or_path, filename)
+
+ net = res_skip()
+ ckpt = torch.load(model_path)
+ for key in list(ckpt.keys()):
+ if 'module.' in key:
+ ckpt[key.replace('module.', '')] = ckpt[key]
+ del ckpt[key]
+ net.load_state_dict(ckpt)
+ net.eval()
+ return cls(net)
+
+ def to(self, device):
+ self.model.to(device)
+ self.device = device
+ return self
+
+ def __call__(self, input_image, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ detected_map, remove_pad = resize_image_with_pad(input_image, 256 * int(np.ceil(float(detect_resolution) / 256.0)), upscale_method)
+
+ img = cv2.cvtColor(detected_map, cv2.COLOR_RGB2GRAY)
+ with torch.no_grad():
+ image_feed = torch.from_numpy(img).float().to(self.device)
+ image_feed = rearrange(image_feed, 'h w -> 1 1 h w')
+
+ line = self.model(image_feed)
+ line = line.cpu().numpy()[0,0,:,:]
+ line[line > 255] = 255
+ line[line < 0] = 0
+
+ line = line.astype(np.uint8)
+
+ detected_map = HWC3(line)
+ detected_map = remove_pad(255 - detected_map)
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
diff --git a/src/custom_controlnet_aux/manga_line/model_torch.py b/src/custom_controlnet_aux/manga_line/model_torch.py
new file mode 100644
index 0000000000000000000000000000000000000000..de5828ccc486d74490b8da710d644651067bd5f3
--- /dev/null
+++ b/src/custom_controlnet_aux/manga_line/model_torch.py
@@ -0,0 +1,196 @@
+import torch.nn as nn
+import numpy as np
+
+#torch.set_printoptions(precision=10)
+
+
+class _bn_relu_conv(nn.Module):
+ def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
+ super(_bn_relu_conv, self).__init__()
+ self.model = nn.Sequential(
+ nn.BatchNorm2d(in_filters, eps=1e-3),
+ nn.LeakyReLU(0.2),
+ nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw//2, fh//2), padding_mode='zeros')
+ )
+
+ def forward(self, x):
+ return self.model(x)
+
+ # the following are for debugs
+ print("****", np.max(x.cpu().numpy()), np.min(x.cpu().numpy()), np.mean(x.cpu().numpy()), np.std(x.cpu().numpy()), x.shape)
+ for i,layer in enumerate(self.model):
+ if i != 2:
+ x = layer(x)
+ else:
+ x = layer(x)
+ #x = nn.functional.pad(x, (1, 1, 1, 1), mode='constant', value=0)
+ print("____", np.max(x.cpu().numpy()), np.min(x.cpu().numpy()), np.mean(x.cpu().numpy()), np.std(x.cpu().numpy()), x.shape)
+ print(x[0])
+ return x
+
+
+class _u_bn_relu_conv(nn.Module):
+ def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
+ super(_u_bn_relu_conv, self).__init__()
+ self.model = nn.Sequential(
+ nn.BatchNorm2d(in_filters, eps=1e-3),
+ nn.LeakyReLU(0.2),
+ nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw//2, fh//2)),
+ nn.Upsample(scale_factor=2, mode='nearest')
+ )
+
+ def forward(self, x):
+ return self.model(x)
+
+
+
+class _shortcut(nn.Module):
+ def __init__(self, in_filters, nb_filters, subsample=1):
+ super(_shortcut, self).__init__()
+ self.process = False
+ self.model = None
+ if in_filters != nb_filters or subsample != 1:
+ self.process = True
+ self.model = nn.Sequential(
+ nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample)
+ )
+
+ def forward(self, x, y):
+ #print(x.size(), y.size(), self.process)
+ if self.process:
+ y0 = self.model(x)
+ #print("merge+", torch.max(y0+y), torch.min(y0+y),torch.mean(y0+y), torch.std(y0+y), y0.shape)
+ return y0 + y
+ else:
+ #print("merge", torch.max(x+y), torch.min(x+y),torch.mean(x+y), torch.std(x+y), y.shape)
+ return x + y
+
+class _u_shortcut(nn.Module):
+ def __init__(self, in_filters, nb_filters, subsample):
+ super(_u_shortcut, self).__init__()
+ self.process = False
+ self.model = None
+ if in_filters != nb_filters:
+ self.process = True
+ self.model = nn.Sequential(
+ nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample, padding_mode='zeros'),
+ nn.Upsample(scale_factor=2, mode='nearest')
+ )
+
+ def forward(self, x, y):
+ if self.process:
+ return self.model(x) + y
+ else:
+ return x + y
+
+
+class basic_block(nn.Module):
+ def __init__(self, in_filters, nb_filters, init_subsample=1):
+ super(basic_block, self).__init__()
+ self.conv1 = _bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
+ self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
+ self.shortcut = _shortcut(in_filters, nb_filters, subsample=init_subsample)
+
+ def forward(self, x):
+ x1 = self.conv1(x)
+ x2 = self.residual(x1)
+ return self.shortcut(x, x2)
+
+class _u_basic_block(nn.Module):
+ def __init__(self, in_filters, nb_filters, init_subsample=1):
+ super(_u_basic_block, self).__init__()
+ self.conv1 = _u_bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
+ self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
+ self.shortcut = _u_shortcut(in_filters, nb_filters, subsample=init_subsample)
+
+ def forward(self, x):
+ y = self.residual(self.conv1(x))
+ return self.shortcut(x, y)
+
+
+class _residual_block(nn.Module):
+ def __init__(self, in_filters, nb_filters, repetitions, is_first_layer=False):
+ super(_residual_block, self).__init__()
+ layers = []
+ for i in range(repetitions):
+ init_subsample = 1
+ if i == repetitions - 1 and not is_first_layer:
+ init_subsample = 2
+ if i == 0:
+ l = basic_block(in_filters=in_filters, nb_filters=nb_filters, init_subsample=init_subsample)
+ else:
+ l = basic_block(in_filters=nb_filters, nb_filters=nb_filters, init_subsample=init_subsample)
+ layers.append(l)
+
+ self.model = nn.Sequential(*layers)
+
+ def forward(self, x):
+ return self.model(x)
+
+
+class _upsampling_residual_block(nn.Module):
+ def __init__(self, in_filters, nb_filters, repetitions):
+ super(_upsampling_residual_block, self).__init__()
+ layers = []
+ for i in range(repetitions):
+ l = None
+ if i == 0:
+ l = _u_basic_block(in_filters=in_filters, nb_filters=nb_filters)#(input)
+ else:
+ l = basic_block(in_filters=nb_filters, nb_filters=nb_filters)#(input)
+ layers.append(l)
+
+ self.model = nn.Sequential(*layers)
+
+ def forward(self, x):
+ return self.model(x)
+
+
+class res_skip(nn.Module):
+
+ def __init__(self):
+ super(res_skip, self).__init__()
+ self.block0 = _residual_block(in_filters=1, nb_filters=24, repetitions=2, is_first_layer=True)#(input)
+ self.block1 = _residual_block(in_filters=24, nb_filters=48, repetitions=3)#(block0)
+ self.block2 = _residual_block(in_filters=48, nb_filters=96, repetitions=5)#(block1)
+ self.block3 = _residual_block(in_filters=96, nb_filters=192, repetitions=7)#(block2)
+ self.block4 = _residual_block(in_filters=192, nb_filters=384, repetitions=12)#(block3)
+
+ self.block5 = _upsampling_residual_block(in_filters=384, nb_filters=192, repetitions=7)#(block4)
+ self.res1 = _shortcut(in_filters=192, nb_filters=192)#(block3, block5, subsample=(1,1))
+
+ self.block6 = _upsampling_residual_block(in_filters=192, nb_filters=96, repetitions=5)#(res1)
+ self.res2 = _shortcut(in_filters=96, nb_filters=96)#(block2, block6, subsample=(1,1))
+
+ self.block7 = _upsampling_residual_block(in_filters=96, nb_filters=48, repetitions=3)#(res2)
+ self.res3 = _shortcut(in_filters=48, nb_filters=48)#(block1, block7, subsample=(1,1))
+
+ self.block8 = _upsampling_residual_block(in_filters=48, nb_filters=24, repetitions=2)#(res3)
+ self.res4 = _shortcut(in_filters=24, nb_filters=24)#(block0,block8, subsample=(1,1))
+
+ self.block9 = _residual_block(in_filters=24, nb_filters=16, repetitions=2, is_first_layer=True)#(res4)
+ self.conv15 = _bn_relu_conv(in_filters=16, nb_filters=1, fh=1, fw=1, subsample=1)#(block7)
+
+ def forward(self, x):
+ x0 = self.block0(x)
+ x1 = self.block1(x0)
+ x2 = self.block2(x1)
+ x3 = self.block3(x2)
+ x4 = self.block4(x3)
+
+ x5 = self.block5(x4)
+ res1 = self.res1(x3, x5)
+
+ x6 = self.block6(res1)
+ res2 = self.res2(x2, x6)
+
+ x7 = self.block7(res2)
+ res3 = self.res3(x1, x7)
+
+ x8 = self.block8(res3)
+ res4 = self.res4(x0, x8)
+
+ x9 = self.block9(res4)
+ y = self.conv15(x9)
+
+ return y
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/mediapipe_face/__init__.py b/src/custom_controlnet_aux/mediapipe_face/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..607db69d6d76224f8450c90212e766d21bf0d33b
--- /dev/null
+++ b/src/custom_controlnet_aux/mediapipe_face/__init__.py
@@ -0,0 +1,31 @@
+import warnings
+from typing import Union
+
+import cv2
+import numpy as np
+from PIL import Image
+
+from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad
+from .mediapipe_face_common import generate_annotation
+
+
+class MediapipeFaceDetector:
+ def __call__(self,
+ input_image: Union[np.ndarray, Image.Image] = None,
+ max_faces: int = 1,
+ min_confidence: float = 0.5,
+ output_type: str = "pil",
+ detect_resolution: int = 512,
+ image_resolution: int = 512,
+ upscale_method="INTER_CUBIC",
+ **kwargs):
+
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+ detected_map = generate_annotation(detected_map, max_faces, min_confidence)
+ detected_map = remove_pad(HWC3(detected_map))
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
diff --git a/src/custom_controlnet_aux/mediapipe_face/mediapipe_face_common.py b/src/custom_controlnet_aux/mediapipe_face/mediapipe_face_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..32eeaf7455df2dd9efa5976def5e617b08757598
--- /dev/null
+++ b/src/custom_controlnet_aux/mediapipe_face/mediapipe_face_common.py
@@ -0,0 +1,156 @@
+from typing import Mapping
+import warnings
+
+import mediapipe as mp
+import numpy
+
+if mp:
+ mp_drawing = mp.solutions.drawing_utils
+ mp_drawing_styles = mp.solutions.drawing_styles
+ mp_face_detection = mp.solutions.face_detection # Only for counting faces.
+ mp_face_mesh = mp.solutions.face_mesh
+ mp_face_connections = mp.solutions.face_mesh_connections.FACEMESH_TESSELATION
+ mp_hand_connections = mp.solutions.hands_connections.HAND_CONNECTIONS
+ mp_body_connections = mp.solutions.pose_connections.POSE_CONNECTIONS
+
+ DrawingSpec = mp.solutions.drawing_styles.DrawingSpec
+ PoseLandmark = mp.solutions.drawing_styles.PoseLandmark
+
+ min_face_size_pixels: int = 64
+ f_thick = 2
+ f_rad = 1
+ right_iris_draw = DrawingSpec(color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad)
+ right_eye_draw = DrawingSpec(color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad)
+ right_eyebrow_draw = DrawingSpec(color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad)
+ left_iris_draw = DrawingSpec(color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad)
+ left_eye_draw = DrawingSpec(color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad)
+ left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad)
+ mouth_draw = DrawingSpec(color=(10, 180, 10), thickness=f_thick, circle_radius=f_rad)
+ head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad)
+
+ # mp_face_mesh.FACEMESH_CONTOURS has all the items we care about.
+ face_connection_spec = {}
+ for edge in mp_face_mesh.FACEMESH_FACE_OVAL:
+ face_connection_spec[edge] = head_draw
+ for edge in mp_face_mesh.FACEMESH_LEFT_EYE:
+ face_connection_spec[edge] = left_eye_draw
+ for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW:
+ face_connection_spec[edge] = left_eyebrow_draw
+ # for edge in mp_face_mesh.FACEMESH_LEFT_IRIS:
+ # face_connection_spec[edge] = left_iris_draw
+ for edge in mp_face_mesh.FACEMESH_RIGHT_EYE:
+ face_connection_spec[edge] = right_eye_draw
+ for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW:
+ face_connection_spec[edge] = right_eyebrow_draw
+ # for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS:
+ # face_connection_spec[edge] = right_iris_draw
+ for edge in mp_face_mesh.FACEMESH_LIPS:
+ face_connection_spec[edge] = mouth_draw
+ iris_landmark_spec = {468: right_iris_draw, 473: left_iris_draw}
+
+
+def draw_pupils(image, landmark_list, drawing_spec, halfwidth: int = 2):
+ """We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all
+ landmarks. Until our PR is merged into mediapipe, we need this separate method."""
+ if len(image.shape) != 3:
+ raise ValueError("Input image must be H,W,C.")
+ image_rows, image_cols, image_channels = image.shape
+ if image_channels != 3: # BGR channels
+ raise ValueError('Input image must contain three channel bgr data.')
+ for idx, landmark in enumerate(landmark_list.landmark):
+ if (
+ (landmark.HasField('visibility') and landmark.visibility < 0.9) or
+ (landmark.HasField('presence') and landmark.presence < 0.5)
+ ):
+ continue
+ if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0:
+ continue
+ image_x = int(image_cols*landmark.x)
+ image_y = int(image_rows*landmark.y)
+ draw_color = None
+ if isinstance(drawing_spec, Mapping):
+ if drawing_spec.get(idx) is None:
+ continue
+ else:
+ draw_color = drawing_spec[idx].color
+ elif isinstance(drawing_spec, DrawingSpec):
+ draw_color = drawing_spec.color
+ image[image_y-halfwidth:image_y+halfwidth, image_x-halfwidth:image_x+halfwidth, :] = draw_color
+
+
+def reverse_channels(image):
+ """Given a numpy array in RGB form, convert to BGR. Will also convert from BGR to RGB."""
+ # im[:,:,::-1] is a neat hack to convert BGR to RGB by reversing the indexing order.
+ # im[:,:,::[2,1,0]] would also work but makes a copy of the data.
+ return image[:, :, ::-1]
+
+
+def generate_annotation(
+ img_rgb,
+ max_faces: int,
+ min_confidence: float
+):
+ """
+ Find up to 'max_faces' inside the provided input image.
+ If min_face_size_pixels is provided and nonzero it will be used to filter faces that occupy less than this many
+ pixels in the image.
+ """
+ with mp_face_mesh.FaceMesh(
+ static_image_mode=True,
+ max_num_faces=max_faces,
+ refine_landmarks=True,
+ min_detection_confidence=min_confidence,
+ ) as facemesh:
+ img_height, img_width, img_channels = img_rgb.shape
+ assert(img_channels == 3)
+
+ results = facemesh.process(img_rgb).multi_face_landmarks
+
+ if results is None:
+ print("No faces detected in controlnet image for Mediapipe face annotator.")
+ return numpy.zeros_like(img_rgb)
+
+ # Filter faces that are too small
+ filtered_landmarks = []
+ for lm in results:
+ landmarks = lm.landmark
+ face_rect = [
+ landmarks[0].x,
+ landmarks[0].y,
+ landmarks[0].x,
+ landmarks[0].y,
+ ] # Left, up, right, down.
+ for i in range(len(landmarks)):
+ face_rect[0] = min(face_rect[0], landmarks[i].x)
+ face_rect[1] = min(face_rect[1], landmarks[i].y)
+ face_rect[2] = max(face_rect[2], landmarks[i].x)
+ face_rect[3] = max(face_rect[3], landmarks[i].y)
+ if min_face_size_pixels > 0:
+ face_width = abs(face_rect[2] - face_rect[0])
+ face_height = abs(face_rect[3] - face_rect[1])
+ face_width_pixels = face_width * img_width
+ face_height_pixels = face_height * img_height
+ face_size = min(face_width_pixels, face_height_pixels)
+ if face_size >= min_face_size_pixels:
+ filtered_landmarks.append(lm)
+ else:
+ filtered_landmarks.append(lm)
+
+ # Annotations are drawn in BGR for some reason, but we don't need to flip a zero-filled image at the start.
+ empty = numpy.zeros_like(img_rgb)
+
+ # Draw detected faces:
+ for face_landmarks in filtered_landmarks:
+ mp_drawing.draw_landmarks(
+ empty,
+ face_landmarks,
+ connections=face_connection_spec.keys(),
+ landmark_drawing_spec=None,
+ connection_drawing_spec=face_connection_spec
+ )
+ draw_pupils(empty, face_landmarks, iris_landmark_spec, 2)
+
+ # Flip BGR back to RGB.
+ empty = reverse_channels(empty).copy()
+
+ return empty
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/mesh_graphormer/__init__.py b/src/custom_controlnet_aux/mesh_graphormer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7a35976d1936a2a9356894fadbcede71c6eff25
--- /dev/null
+++ b/src/custom_controlnet_aux/mesh_graphormer/__init__.py
@@ -0,0 +1,48 @@
+import cv2
+import numpy as np
+from PIL import Image
+from custom_controlnet_aux.util import resize_image_with_pad, common_input_validate, HWC3, custom_hf_download, MESH_GRAPHORMER_MODEL_NAME
+from custom_controlnet_aux.mesh_graphormer.pipeline import MeshGraphormerMediapipe, args
+import random, torch
+
+def set_seed(seed, n_gpu):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ if n_gpu > 0:
+ torch.cuda.manual_seed_all(seed)
+
+class MeshGraphormerDetector:
+ def __init__(self, pipeline):
+ self.pipeline = pipeline
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path=MESH_GRAPHORMER_MODEL_NAME, filename="graphormer_hand_state_dict.bin", hrnet_filename="hrnetv2_w64_imagenet_pretrained.pth", detect_thr=0.6, presence_thr=0.6):
+ args.resume_checkpoint = custom_hf_download(pretrained_model_or_path, filename)
+ args.hrnet_checkpoint = custom_hf_download(pretrained_model_or_path, hrnet_filename)
+ pipeline = MeshGraphormerMediapipe(args, detect_thr=detect_thr, presence_thr=presence_thr)
+ return cls(pipeline)
+
+ def to(self, device):
+ self.pipeline._model.to(device)
+ self.pipeline.mano_model.to(device)
+ self.pipeline.mano_model.layer.to(device)
+ return self
+
+ def __call__(self, input_image=None, mask_bbox_padding=30, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", seed=88, **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ set_seed(seed, 0)
+ depth_map, mask, info = self.pipeline.get_depth(input_image, mask_bbox_padding)
+ if depth_map is None:
+ depth_map = np.zeros_like(input_image)
+ mask = np.zeros_like(input_image)
+
+ #The hand is small
+ depth_map, mask = HWC3(depth_map), HWC3(mask)
+ depth_map, remove_pad = resize_image_with_pad(depth_map, detect_resolution, upscale_method)
+ depth_map = remove_pad(depth_map)
+ if output_type == "pil":
+ depth_map = Image.fromarray(depth_map)
+ mask = Image.fromarray(mask)
+
+ return depth_map, mask, info
diff --git a/src/custom_controlnet_aux/mesh_graphormer/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml b/src/custom_controlnet_aux/mesh_graphormer/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..65d710339c15ac3fc94b91bc54b51196feb433b1
--- /dev/null
+++ b/src/custom_controlnet_aux/mesh_graphormer/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml
@@ -0,0 +1,92 @@
+GPUS: (0,1,2,3)
+LOG_DIR: 'log/'
+DATA_DIR: ''
+OUTPUT_DIR: 'output/'
+WORKERS: 4
+PRINT_FREQ: 1000
+
+MODEL:
+ NAME: cls_hrnet
+ IMAGE_SIZE:
+ - 224
+ - 224
+ EXTRA:
+ STAGE1:
+ NUM_MODULES: 1
+ NUM_RANCHES: 1
+ BLOCK: BOTTLENECK
+ NUM_BLOCKS:
+ - 4
+ NUM_CHANNELS:
+ - 64
+ FUSE_METHOD: SUM
+ STAGE2:
+ NUM_MODULES: 1
+ NUM_BRANCHES: 2
+ BLOCK: BASIC
+ NUM_BLOCKS:
+ - 4
+ - 4
+ NUM_CHANNELS:
+ - 64
+ - 128
+ FUSE_METHOD: SUM
+ STAGE3:
+ NUM_MODULES: 4
+ NUM_BRANCHES: 3
+ BLOCK: BASIC
+ NUM_BLOCKS:
+ - 4
+ - 4
+ - 4
+ NUM_CHANNELS:
+ - 64
+ - 128
+ - 256
+ FUSE_METHOD: SUM
+ STAGE4:
+ NUM_MODULES: 3
+ NUM_BRANCHES: 4
+ BLOCK: BASIC
+ NUM_BLOCKS:
+ - 4
+ - 4
+ - 4
+ - 4
+ NUM_CHANNELS:
+ - 64
+ - 128
+ - 256
+ - 512
+ FUSE_METHOD: SUM
+CUDNN:
+ BENCHMARK: true
+ DETERMINISTIC: false
+ ENABLED: true
+DATASET:
+ DATASET: 'imagenet'
+ DATA_FORMAT: 'jpg'
+ ROOT: 'data/imagenet/'
+ TEST_SET: 'val'
+ TRAIN_SET: 'train'
+TEST:
+ BATCH_SIZE_PER_GPU: 32
+ MODEL_FILE: ''
+TRAIN:
+ BATCH_SIZE_PER_GPU: 32
+ BEGIN_EPOCH: 0
+ END_EPOCH: 100
+ RESUME: true
+ LR_FACTOR: 0.1
+ LR_STEP:
+ - 30
+ - 60
+ - 90
+ OPTIMIZER: sgd
+ LR: 0.05
+ WD: 0.0001
+ MOMENTUM: 0.9
+ NESTEROV: true
+ SHUFFLE: true
+DEBUG:
+ DEBUG: false
diff --git a/src/custom_controlnet_aux/mesh_graphormer/depth_preprocessor.py b/src/custom_controlnet_aux/mesh_graphormer/depth_preprocessor.py
new file mode 100644
index 0000000000000000000000000000000000000000..496313ad8cff2337bf6b4bee2467082caa838c96
--- /dev/null
+++ b/src/custom_controlnet_aux/mesh_graphormer/depth_preprocessor.py
@@ -0,0 +1,6 @@
+class Preprocessor:
+ def __init__(self) -> None:
+ pass
+
+ def get_depth(self, input_dir, file_name):
+ return
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/mesh_graphormer/hand_landmarker.task b/src/custom_controlnet_aux/mesh_graphormer/hand_landmarker.task
new file mode 100644
index 0000000000000000000000000000000000000000..5ecab741879892d97c2f90bbf03bf55d7213db7c
--- /dev/null
+++ b/src/custom_controlnet_aux/mesh_graphormer/hand_landmarker.task
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fbc2a30080c3c557093b5ddfc334698132eb341044ccee322ccf8bcf3607cde1
+size 7819105
diff --git a/src/custom_controlnet_aux/mesh_graphormer/pipeline.py b/src/custom_controlnet_aux/mesh_graphormer/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..325770e45cba516c5d8f4b95a956e1b2ab3952e5
--- /dev/null
+++ b/src/custom_controlnet_aux/mesh_graphormer/pipeline.py
@@ -0,0 +1,472 @@
+import os
+import torch
+import gc
+import numpy as np
+from custom_controlnet_aux.mesh_graphormer.depth_preprocessor import Preprocessor
+
+import torchvision.models as models
+from custom_mesh_graphormer.modeling.bert import BertConfig, Graphormer
+from custom_mesh_graphormer.modeling.bert import Graphormer_Hand_Network as Graphormer_Network
+from custom_mesh_graphormer.modeling._mano import MANO, Mesh
+from custom_mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat
+from custom_mesh_graphormer.modeling.hrnet.config import config as hrnet_config
+from custom_mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
+from custom_mesh_graphormer.utils.miscellaneous import set_seed
+from argparse import Namespace
+from pathlib import Path
+import cv2
+from torchvision import transforms
+import numpy as np
+import cv2
+from trimesh import Trimesh
+from trimesh.ray.ray_triangle import RayMeshIntersector
+import mediapipe as mp
+from mediapipe.tasks import python
+from mediapipe.tasks.python import vision
+from torchvision import transforms
+from pathlib import Path
+from custom_controlnet_aux.util import custom_hf_download
+import custom_mesh_graphormer
+from comfy.model_management import soft_empty_cache
+from packaging import version
+
+args = Namespace(
+ num_workers=4,
+ img_scale_factor=1,
+ image_file_or_path=os.path.join('', 'MeshGraphormer', 'samples', 'hand'),
+ model_name_or_path=str(Path(custom_mesh_graphormer.__file__).parent / "modeling/bert/bert-base-uncased"),
+ resume_checkpoint=None,
+ output_dir='output/',
+ config_name='',
+ a='hrnet-w64',
+ arch='hrnet-w64',
+ num_hidden_layers=4,
+ hidden_size=-1,
+ num_attention_heads=4,
+ intermediate_size=-1,
+ input_feat_dim='2051,512,128',
+ hidden_feat_dim='1024,256,64',
+ which_gcn='0,0,1',
+ mesh_type='hand',
+ run_eval_only=True,
+ device="cpu",
+ seed=88,
+ hrnet_checkpoint=custom_hf_download("hr16/ControlNet-HandRefiner-pruned", 'hrnetv2_w64_imagenet_pretrained.pth')
+)
+
+#Since mediapipe v0.10.5, the hand category has been correct
+if version.parse(mp.__version__) >= version.parse('0.10.5'):
+ true_hand_category = {"Right": "right", "Left": "left"}
+else:
+ true_hand_category = {"Right": "left", "Left": "right"}
+
+class MeshGraphormerMediapipe(Preprocessor):
+ def __init__(self, args=args, detect_thr=0.6, presence_thr=0.6) -> None:
+ #global logger
+ # Setup CUDA, GPU & distributed training
+ args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
+ os.environ['OMP_NUM_THREADS'] = str(args.num_workers)
+ print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS']))
+
+ #mkdir(args.output_dir)
+ #logger = setup_logger("Graphormer", args.output_dir, get_rank())
+ set_seed(args.seed, args.num_gpus)
+ #logger.info("Using {} GPUs".format(args.num_gpus))
+
+ # Mesh and MANO utils
+ mano_model = MANO().to(args.device)
+ mano_model.layer = mano_model.layer.to(args.device)
+ mesh_sampler = Mesh(device=args.device)
+
+ # Renderer for visualization
+ # renderer = Renderer(faces=mano_model.face)
+
+ # Load pretrained model
+ trans_encoder = []
+
+ input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')]
+ hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')]
+ output_feat_dim = input_feat_dim[1:] + [3]
+
+ # which encoder block to have graph convs
+ which_blk_graph = [int(item) for item in args.which_gcn.split(',')]
+
+ if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint:
+ # if only run eval, load checkpoint
+ #logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint))
+ _model = torch.load(args.resume_checkpoint)
+
+ else:
+ # init three transformer-encoder blocks in a loop
+ for i in range(len(output_feat_dim)):
+ config_class, model_class = BertConfig, Graphormer
+ config = config_class.from_pretrained(args.config_name if args.config_name \
+ else args.model_name_or_path, attn_implementation="eager")
+
+ config.output_attentions = False
+ config.img_feature_dim = input_feat_dim[i]
+ config.output_feature_dim = output_feat_dim[i]
+ args.hidden_size = hidden_feat_dim[i]
+ args.intermediate_size = int(args.hidden_size*2)
+
+ if which_blk_graph[i]==1:
+ config.graph_conv = True
+ #logger.info("Add Graph Conv")
+ else:
+ config.graph_conv = False
+
+ config.mesh_type = args.mesh_type
+
+ # update model structure if specified in arguments
+ update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size']
+ for idx, param in enumerate(update_params):
+ arg_param = getattr(args, param)
+ config_param = getattr(config, param)
+ if arg_param > 0 and arg_param != config_param:
+ #logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param))
+ setattr(config, param, arg_param)
+
+ # init a transformer encoder and append it to a list
+ assert config.hidden_size % config.num_attention_heads == 0
+ model = model_class(config=config)
+ #logger.info("Init model from scratch.")
+ trans_encoder.append(model)
+
+ # create backbone model
+ if args.arch=='hrnet':
+ hrnet_yaml = Path(__file__).parent / 'cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
+ hrnet_checkpoint = args.hrnet_checkpoint
+ hrnet_update_config(hrnet_config, hrnet_yaml)
+ backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
+ #logger.info('=> loading hrnet-v2-w40 model')
+ elif args.arch=='hrnet-w64':
+ hrnet_yaml = Path(__file__).parent / 'cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
+ hrnet_checkpoint = args.hrnet_checkpoint
+ hrnet_update_config(hrnet_config, hrnet_yaml)
+ backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
+ #logger.info('=> loading hrnet-v2-w64 model')
+ else:
+ print("=> using pre-trained model '{}'".format(args.arch))
+ backbone = models.__dict__[args.arch](pretrained=True)
+ # remove the last fc layer
+ backbone = torch.nn.Sequential(*list(backbone.children())[:-1])
+
+ trans_encoder = torch.nn.Sequential(*trans_encoder)
+ total_params = sum(p.numel() for p in trans_encoder.parameters())
+ #logger.info('Graphormer encoders total parameters: {}'.format(total_params))
+ backbone_total_params = sum(p.numel() for p in backbone.parameters())
+ #logger.info('Backbone total parameters: {}'.format(backbone_total_params))
+
+ # build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder)
+ _model = Graphormer_Network(args, config, backbone, trans_encoder)
+
+ if args.resume_checkpoint!=None and args.resume_checkpoint!='None':
+ # for fine-tuning or resume training or inference, load weights from checkpoint
+ #logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint))
+ # workaround approach to load sparse tensor in graph conv.
+ state_dict = torch.load(args.resume_checkpoint)
+ _model.load_state_dict(state_dict, strict=False)
+ del state_dict
+ gc.collect()
+ soft_empty_cache()
+
+ # update configs to enable attention outputs
+ setattr(_model.trans_encoder[-1].config,'output_attentions', True)
+ setattr(_model.trans_encoder[-1].config,'output_hidden_states', True)
+ _model.trans_encoder[-1].bert.encoder.output_attentions = True
+ _model.trans_encoder[-1].bert.encoder.output_hidden_states = True
+ for iter_layer in range(4):
+ _model.trans_encoder[-1].bert.encoder.layer[iter_layer].attention.self.output_attentions = True
+ for inter_block in range(3):
+ setattr(_model.trans_encoder[-1].config,'device', args.device)
+
+ _model.to(args.device)
+ self._model = _model
+ self.mano_model = mano_model
+ self.mesh_sampler = mesh_sampler
+
+ self.transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])])
+ #Fix File loading is not yet supported on Windows
+ with open(str( Path(__file__).parent / "hand_landmarker.task" ), 'rb') as file:
+ model_data = file.read()
+ base_options = python.BaseOptions(model_asset_buffer=model_data)
+ options = vision.HandLandmarkerOptions(base_options=base_options,
+ min_hand_detection_confidence=detect_thr,
+ min_hand_presence_confidence=presence_thr,
+ min_tracking_confidence=0.6,
+ num_hands=2)
+
+ self.detector = vision.HandLandmarker.create_from_options(options)
+
+
+ def get_rays(self, W, H, fx, fy, cx, cy, c2w_t, center_pixels): # rot = I
+
+ j, i = np.meshgrid(np.arange(H, dtype=np.float32), np.arange(W, dtype=np.float32))
+ if center_pixels:
+ i = i.copy() + 0.5
+ j = j.copy() + 0.5
+
+ directions = np.stack([(i - cx) / fx, (j - cy) / fy, np.ones_like(i)], -1)
+ directions /= np.linalg.norm(directions, axis=-1, keepdims=True)
+
+ rays_o = np.expand_dims(c2w_t,0).repeat(H*W, 0)
+
+ rays_d = directions # (H, W, 3)
+ rays_d = (rays_d / np.linalg.norm(rays_d, axis=-1, keepdims=True)).reshape(-1,3)
+
+ return rays_o, rays_d
+
+ def get_mask_bounding_box(self, extrema, H, W, padding=30, dynamic_resize=0.15):
+ x_min, x_max, y_min, y_max = extrema
+ bb_xpad = max(int((x_max - x_min + 1) * dynamic_resize), padding)
+ bb_ypad = max(int((y_max - y_min + 1) * dynamic_resize), padding)
+ bbx_min = np.max((x_min - bb_xpad, 0))
+ bbx_max = np.min((x_max + bb_xpad, W-1))
+ bby_min = np.max((y_min - bb_ypad, 0))
+ bby_max = np.min((y_max + bb_ypad, H-1))
+ return bbx_min, bbx_max, bby_min, bby_max
+
+ def run_inference(self, img, Graphormer_model, mano, mesh_sampler, scale, crop_len):
+ global args
+ H, W = int(crop_len), int(crop_len)
+ Graphormer_model.eval()
+ mano.eval()
+ device = next(Graphormer_model.parameters()).device
+ with torch.no_grad():
+ img_tensor = self.transform(img)
+ batch_imgs = torch.unsqueeze(img_tensor, 0).to(device)
+
+ # forward-pass
+ pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, mano, mesh_sampler)
+
+ # obtain 3d joints, which are regressed from the full mesh
+ pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices)
+ # obtain 2d joints, which are projected from 3d joints of mesh
+ #pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous())
+ #pred_2d_coarse_vertices_from_mesh = orthographic_projection(pred_vertices_sub.contiguous(), pred_camera.contiguous())
+ pred_camera = pred_camera.cpu()
+ pred_vertices = pred_vertices.cpu()
+ mesh = Trimesh(vertices=pred_vertices[0], faces=mano.face)
+ res = crop_len
+ focal_length = 1000 * scale
+ camera_t = np.array([-pred_camera[1], -pred_camera[2], -2*focal_length/(res * pred_camera[0] +1e-9)])
+ pred_3d_joints_camera = pred_3d_joints_from_mesh.cpu()[0] - camera_t
+ z_3d_dist = pred_3d_joints_camera[:,2].clone()
+
+ pred_2d_joints_img_space = ((pred_3d_joints_camera/z_3d_dist[:,None]) * np.array((focal_length, focal_length, 1)))[:,:2] + np.array((W/2, H/2))
+
+ rays_o, rays_d = self.get_rays(W, H, focal_length, focal_length, W/2, H/2, camera_t, True)
+ coords = np.array(list(np.ndindex(H,W))).reshape(H,W,-1).transpose(1,0,2).reshape(-1,2)
+ intersector = RayMeshIntersector(mesh)
+ points, index_ray, _ = intersector.intersects_location(rays_o, rays_d, multiple_hits=False)
+
+ tri_index = intersector.intersects_first(rays_o, rays_d)
+
+ tri_index = tri_index[index_ray]
+
+ assert len(index_ray) == len(tri_index)
+
+ discriminator = (np.sum(mesh.face_normals[tri_index]* rays_d[index_ray], axis=-1)<= 0)
+ points = points[discriminator] # ray intesects in interior faces, discard them
+
+ if len(points) == 0:
+ return None, None
+ depth = (points + camera_t)[:,-1]
+ index_ray = index_ray[discriminator]
+ pixel_ray = coords[index_ray]
+
+ minval = np.min(depth)
+ maxval = np.max(depth)
+ depthmap = np.zeros([H,W])
+
+ depthmap[pixel_ray[:, 0], pixel_ray[:, 1]] = 1.0 - (0.8 * (depth - minval) / (maxval - minval))
+ depthmap *= 255
+ return depthmap, pred_2d_joints_img_space
+
+
+ def get_depth(self, np_image, padding):
+ info = {}
+
+ # STEP 3: Load the input image.
+ #https://stackoverflow.com/a/76407270
+ image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np_image.copy())
+
+ # STEP 4: Detect hand landmarks from the input image.
+ detection_result = self.detector.detect(image)
+
+ handedness_list = detection_result.handedness
+ hand_landmarks_list = detection_result.hand_landmarks
+
+ raw_image = image.numpy_view()
+ H, W, C = raw_image.shape
+
+
+ # HANDLANDMARKS CAN BE EMPTY, HANDLE THIS!
+ if len(hand_landmarks_list) == 0:
+ return None, None, None
+ raw_image = raw_image[:, :, :3]
+
+ padded_image = np.zeros((H*2, W*2, 3))
+ padded_image[int(1/2 * H):int(3/2 * H), int(1/2 * W):int(3/2 * W)] = raw_image
+
+ hand_landmarks_list, handedness_list = zip(
+ *sorted(
+ zip(hand_landmarks_list, handedness_list), key=lambda x: x[0][9].z, reverse=True
+ )
+ )
+
+ padded_depthmap = np.zeros((H*2, W*2))
+ mask = np.zeros((H, W))
+ crop_boxes = []
+ #bboxes = []
+ groundtruth_2d_keypoints = []
+ hands = []
+ depth_failure = False
+ crop_lens = []
+ abs_boxes = []
+
+ for idx in range(len(hand_landmarks_list)):
+ hand = true_hand_category[handedness_list[idx][0].category_name]
+ hands.append(hand)
+ hand_landmarks = hand_landmarks_list[idx]
+ handedness = handedness_list[idx]
+ height, width, _ = raw_image.shape
+ x_coordinates = [landmark.x for landmark in hand_landmarks]
+ y_coordinates = [landmark.y for landmark in hand_landmarks]
+
+ # x_min, x_max, y_min, y_max: extrema from mediapipe keypoint detection
+ x_min = int(min(x_coordinates) * width)
+ x_max = int(max(x_coordinates) * width)
+ x_c = (x_min + x_max)//2
+ y_min = int(min(y_coordinates) * height)
+ y_max = int(max(y_coordinates) * height)
+ y_c = (y_min + y_max)//2
+ abs_boxes.append([x_min, x_max, y_min, y_max])
+
+ #if x_max - x_min < 60 or y_max - y_min < 60:
+ # continue
+
+ crop_len = (max(x_max - x_min, y_max - y_min) * 1.6) //2 * 2
+
+ # crop_x_min, crop_x_max, crop_y_min, crop_y_max: bounding box for mesh reconstruction
+ crop_x_min = int(x_c - (crop_len/2 - 1) + W/2)
+ crop_x_max = int(x_c + crop_len/2 + W/2)
+ crop_y_min = int(y_c - (crop_len/2 - 1) + H/2)
+ crop_y_max = int(y_c + crop_len/2 + H/2)
+
+ cropped = padded_image[crop_y_min:crop_y_max+1, crop_x_min:crop_x_max+1]
+ crop_boxes.append([crop_y_min, crop_y_max, crop_x_min, crop_x_max])
+ crop_lens.append(crop_len)
+ if hand == "left":
+ cropped = cv2.flip(cropped, 1)
+
+ if crop_len < 224:
+ graphormer_input = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_CUBIC)
+ else:
+ graphormer_input = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_AREA)
+ scale = crop_len/224
+ cropped_depthmap, pred_2d_keypoints = self.run_inference(graphormer_input.astype(np.uint8), self._model, self.mano_model, self.mesh_sampler, scale, int(crop_len))
+
+ if cropped_depthmap is None:
+ depth_failure = True
+ break
+ #keypoints_image_space = pred_2d_keypoints * (crop_y_max - crop_y_min + 1)/224
+ groundtruth_2d_keypoints.append(pred_2d_keypoints)
+
+ if hand == "left":
+ cropped_depthmap = cv2.flip(cropped_depthmap, 1)
+ resized_cropped_depthmap = cv2.resize(cropped_depthmap, (int(crop_len), int(crop_len)), interpolation=cv2.INTER_LINEAR)
+ nonzero_y, nonzero_x = (resized_cropped_depthmap != 0).nonzero()
+ if len(nonzero_y) == 0 or len(nonzero_x) == 0:
+ depth_failure = True
+ break
+ padded_depthmap[crop_y_min+nonzero_y, crop_x_min+nonzero_x] = resized_cropped_depthmap[nonzero_y, nonzero_x]
+
+ # nonzero stands for nonzero value on the depth map
+ # coordinates of nonzero depth pixels in original image space
+ original_nonzero_x = crop_x_min+nonzero_x - int(W/2)
+ original_nonzero_y = crop_y_min+nonzero_y - int(H/2)
+
+ nonzerox_min = min(np.min(original_nonzero_x), x_min)
+ nonzerox_max = max(np.max(original_nonzero_x), x_max)
+ nonzeroy_min = min(np.min(original_nonzero_y), y_min)
+ nonzeroy_max = max(np.max(original_nonzero_y), y_max)
+
+ bbx_min, bbx_max, bby_min, bby_max = self.get_mask_bounding_box((nonzerox_min, nonzerox_max, nonzeroy_min, nonzeroy_max), H, W, padding)
+ mask[bby_min:bby_max+1, bbx_min:bbx_max+1] = 1.0
+ #bboxes.append([int(bbx_min), int(bbx_max), int(bby_min), int(bby_max)])
+ if depth_failure:
+ #print("cannot detect normal hands")
+ return None, None, None
+ depthmap = padded_depthmap[int(1/2 * H):int(3/2 * H), int(1/2 * W):int(3/2 * W)].astype(np.uint8)
+ mask = (255.0 * mask).astype(np.uint8)
+ info["groundtruth_2d_keypoints"] = groundtruth_2d_keypoints
+ info["hands"] = hands
+ info["crop_boxes"] = crop_boxes
+ info["crop_lens"] = crop_lens
+ info["abs_boxes"] = abs_boxes
+ return depthmap, mask, info
+
+ def get_keypoints(self, img, Graphormer_model, mano, mesh_sampler, scale, crop_len):
+ global args
+ H, W = int(crop_len), int(crop_len)
+ Graphormer_model.eval()
+ mano.eval()
+ device = next(Graphormer_model.parameters()).device
+ with torch.no_grad():
+ img_tensor = self.transform(img)
+ #print(img_tensor)
+ batch_imgs = torch.unsqueeze(img_tensor, 0).to(device)
+
+ # forward-pass
+ pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, mano, mesh_sampler)
+
+ # obtain 3d joints, which are regressed from the full mesh
+ pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices)
+ # obtain 2d joints, which are projected from 3d joints of mesh
+ #pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous())
+ #pred_2d_coarse_vertices_from_mesh = orthographic_projection(pred_vertices_sub.contiguous(), pred_camera.contiguous())
+ pred_camera = pred_camera.cpu()
+ pred_vertices = pred_vertices.cpu()
+ #
+ res = crop_len
+ focal_length = 1000 * scale
+ camera_t = np.array([-pred_camera[1], -pred_camera[2], -2*focal_length/(res * pred_camera[0] +1e-9)])
+ pred_3d_joints_camera = pred_3d_joints_from_mesh.cpu()[0] - camera_t
+ z_3d_dist = pred_3d_joints_camera[:,2].clone()
+ pred_2d_joints_img_space = ((pred_3d_joints_camera/z_3d_dist[:,None]) * np.array((focal_length, focal_length, 1)))[:,:2] + np.array((W/2, H/2))
+
+ return pred_2d_joints_img_space
+
+
+ def eval_mpjpe(self, sample, info):
+ H, W, C = sample.shape
+ padded_image = np.zeros((H*2, W*2, 3))
+ padded_image[int(1/2 * H):int(3/2 * H), int(1/2 * W):int(3/2 * W)] = sample
+ crop_boxes = info["crop_boxes"]
+ hands = info["hands"]
+ groundtruth_2d_keypoints = info["groundtruth_2d_keypoints"]
+ crop_lens = info["crop_lens"]
+ pjpe = 0
+ for i in range(len(crop_boxes)):#box in crop_boxes:
+ crop_y_min, crop_y_max, crop_x_min, crop_x_max = crop_boxes[i]
+ cropped = padded_image[crop_y_min:crop_y_max+1, crop_x_min:crop_x_max+1]
+ hand = hands[i]
+ if hand == "left":
+ cropped = cv2.flip(cropped, 1)
+ crop_len = crop_lens[i]
+ scale = crop_len/224
+ if crop_len < 224:
+ graphormer_input = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_CUBIC)
+ else:
+ graphormer_input = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_AREA)
+ generated_keypoint = self.get_keypoints(graphormer_input.astype(np.uint8), self._model, self.mano_model, self.mesh_sampler, scale, crop_len)
+ #generated_keypoint = generated_keypoint * ((crop_y_max - crop_y_min + 1)/224)
+ pjpe += np.sum(np.sqrt(np.sum(((generated_keypoint - groundtruth_2d_keypoints[i]) ** 2).numpy(), axis=1)))
+ pass
+ mpjpe = pjpe/(len(crop_boxes) * 21)
+ return mpjpe
diff --git a/src/custom_controlnet_aux/metric3d/__init__.py b/src/custom_controlnet_aux/metric3d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a8ed68d69d2fa9fdd54d2fb858cd203c5e1a7b1
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/__init__.py
@@ -0,0 +1,124 @@
+
+import torch
+import os
+from pathlib import Path
+
+CODE_SPACE=Path(os.path.dirname(os.path.abspath(__file__)))
+
+from custom_mmpkg.custom_mmcv.utils import Config, DictAction
+from .mono.model.monodepth_model import get_configured_monodepth_model
+from .mono.utils.running import load_ckpt
+from .mono.utils.do_test import transform_test_data_scalecano, get_prediction
+import numpy as np
+from .mono.utils.visualization import vis_surface_normal
+from einops import repeat
+from PIL import Image
+from ..util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, METRIC3D_MODEL_NAME
+import re
+import matplotlib.pyplot as plt
+
+def load_model(model_selection, model_path):
+ if model_selection == "vit-small":
+ cfg = Config.fromfile(CODE_SPACE / 'mono/configs/HourglassDecoder/vit.raft5.small.py')
+ elif model_selection == "vit-large":
+ cfg = Config.fromfile(CODE_SPACE / 'mono/configs/HourglassDecoder/vit.raft5.large.py')
+ elif model_selection == "vit-giant2":
+ cfg = Config.fromfile(CODE_SPACE / 'mono/configs/HourglassDecoder/vit.raft5.giant2.py')
+ else:
+ raise NotImplementedError(f"metric3d model: {model_selection}")
+ model = get_configured_monodepth_model(cfg, )
+ model, _, _, _ = load_ckpt(model_path, model, strict_match=False)
+ model.eval()
+ model = model
+ return model, cfg
+
+def gray_to_colormap(img, cmap='rainbow'):
+ """
+ Transfer gray map to matplotlib colormap
+ """
+ assert img.ndim == 2
+
+ img[img<0] = 0
+ mask_invalid = img < 1e-10
+ img = img / (img.max() + 1e-8)
+ norm = plt.Normalize(vmin=0, vmax=1.1) # Use plt.Normalize instead of matplotlib.colors.Normalize
+ cmap_m = plt.get_cmap(cmap) # Access the colormap directly from plt
+ map = plt.cm.ScalarMappable(norm=norm, cmap=cmap_m)
+ colormap = (map.to_rgba(img)[:, :, :3] * 255).astype(np.uint8)
+ colormap[mask_invalid] = 0
+ return colormap
+
+def predict_depth_normal(model, cfg, np_img, fx=1000.0, fy=1000.0, state_cache={}):
+ intrinsic = [fx, fy, np_img.shape[1]/2, np_img.shape[0]/2]
+ rgb_input, cam_models_stacks, pad, label_scale_factor = transform_test_data_scalecano(np_img, intrinsic, cfg.data_basic, device=next(model.parameters()).device)
+
+ with torch.no_grad():
+ pred_depth, confidence, output = get_prediction(
+ model = model,
+ input = rgb_input.unsqueeze(0),
+ cam_model = cam_models_stacks,
+ pad_info = pad,
+ scale_info = label_scale_factor,
+ gt_depth = None,
+ normalize_scale = cfg.data_basic.depth_range[1],
+ ori_shape=[np_img.shape[0], np_img.shape[1]],
+ )
+
+ pred_normal = output['normal_out_list'][0][:, :3, :, :]
+ H, W = pred_normal.shape[2:]
+ pred_normal = pred_normal[:, :, pad[0]:H-pad[1], pad[2]:W-pad[3]]
+ pred_depth = pred_depth[:, :, pad[0]:H-pad[1], pad[2]:W-pad[3] ]
+
+ pred_depth = pred_depth.squeeze().cpu().numpy()
+ pred_color = gray_to_colormap(pred_depth, 'Greys')
+
+ pred_normal = torch.nn.functional.interpolate(pred_normal, [np_img.shape[0], np_img.shape[1]], mode='bilinear').squeeze()
+ pred_normal = pred_normal.permute(1,2,0)
+ pred_color_normal = vis_surface_normal(pred_normal)
+ pred_normal = pred_normal.cpu().numpy()
+
+ # Storing depth and normal map in state for potential 3D reconstruction
+ state_cache['depth'] = pred_depth
+ state_cache['normal'] = pred_normal
+ state_cache['img'] = np_img
+ state_cache['intrinsic'] = intrinsic
+ state_cache['confidence'] = confidence
+
+ return pred_color, pred_color_normal, state_cache
+
+class Metric3DDetector:
+ def __init__(self, model, cfg):
+ self.model = model
+ self.cfg = cfg
+ self.device = "cpu"
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path=METRIC3D_MODEL_NAME, filename="metric_depth_vit_small_800k.pth"):
+ model_path = custom_hf_download(pretrained_model_or_path, filename)
+ backbone = re.findall(r"metric_depth_vit_(\w+)_", model_path)[0]
+ model, cfg = load_model(f'vit-{backbone}', model_path)
+ return cls(model, cfg)
+
+ def to(self, device):
+ self.model.to(device)
+ self.device = device
+ return self
+
+ def __call__(self, input_image, detect_resolution=512, fx=1000, fy=1000, output_type=None, upscale_method="INTER_CUBIC", depth_and_normal=True, **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+
+ depth_map, normal_map, _ = predict_depth_normal(self.model, self.cfg, input_image, fx=fx, fy=fy)
+ # ControlNet uses inverse depth and normal
+ depth_map, normal_map = depth_map, 255 - normal_map
+ depth_map, remove_pad = resize_image_with_pad(depth_map, detect_resolution, upscale_method)
+ normal_map, _ = resize_image_with_pad(normal_map, detect_resolution, upscale_method)
+ depth_map, normal_map = remove_pad(depth_map), remove_pad(normal_map)
+
+ if output_type == "pil":
+ depth_map = Image.fromarray(depth_map)
+ normal_map = Image.fromarray(normal_map)
+
+ if depth_and_normal:
+ return depth_map, normal_map
+ else:
+ return depth_map
diff --git a/src/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/convlarge.0.3_150.py b/src/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/convlarge.0.3_150.py
new file mode 100644
index 0000000000000000000000000000000000000000..37b91c80284d6db3df3017ec636f18198e42dc08
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/convlarge.0.3_150.py
@@ -0,0 +1,25 @@
+_base_=[
+ '../_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py',
+ '../_base_/datasets/_data_base_.py',
+ '../_base_/default_runtime.py',
+ ]
+
+model = dict(
+ backbone=dict(
+ pretrained=False,
+ )
+)
+
+# configs of the canonical space
+data_basic=dict(
+ canonical_space = dict(
+ img_size=(512, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.3, 150),
+ crop_size = (544, 1216),
+)
+
+batchsize_per_gpu = 2
+thread_per_gpu = 4
diff --git a/src/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/test_kitti_convlarge.0.3_150.py b/src/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/test_kitti_convlarge.0.3_150.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdd9156b7f2f0921fb01b1adaf9a2a7447332d6e
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/test_kitti_convlarge.0.3_150.py
@@ -0,0 +1,25 @@
+_base_=[
+ '../_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py',
+ '../_base_/datasets/_data_base_.py',
+ '../_base_/default_runtime.py',
+ ]
+
+model = dict(
+ backbone=dict(
+ pretrained=False,
+ )
+)
+
+# configs of the canonical space
+data_basic=dict(
+ canonical_space = dict(
+ img_size=(512, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.3, 150),
+ crop_size = (512, 1088),
+)
+
+batchsize_per_gpu = 2
+thread_per_gpu = 4
diff --git a/src/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/test_nyu_convlarge.0.3_150.py b/src/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/test_nyu_convlarge.0.3_150.py
new file mode 100644
index 0000000000000000000000000000000000000000..6601f5cdfad07c5fad8b89fbf959e67039126dfa
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/test_nyu_convlarge.0.3_150.py
@@ -0,0 +1,25 @@
+_base_=[
+ '../_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py',
+ '../_base_/datasets/_data_base_.py',
+ '../_base_/default_runtime.py',
+ ]
+
+model = dict(
+ backbone=dict(
+ pretrained=False,
+ )
+)
+
+# configs of the canonical space
+data_basic=dict(
+ canonical_space = dict(
+ img_size=(512, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.3, 150),
+ crop_size = (480, 1216),
+)
+
+batchsize_per_gpu = 2
+thread_per_gpu = 4
diff --git a/src/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/vit.raft5.giant2.py b/src/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/vit.raft5.giant2.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4a1238b6eb7ffcc21a237e748bbd7ed75bdf5aa
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/vit.raft5.giant2.py
@@ -0,0 +1,32 @@
+_base_=[
+ '../_base_/models/encoder_decoder/dino_vit_giant2_reg.dpt_raft.py',
+ '../_base_/datasets/_data_base_.py',
+ '../_base_/default_runtime.py',
+ ]
+
+model=dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+
+max_value = 200
+# configs of the canonical space
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.1, max_value),
+ crop_size = (616, 1064), # %28 = 0
+ clip_depth_range=(0.1, 200),
+ vit_size=(616,1064)
+)
+
+batchsize_per_gpu = 1
+thread_per_gpu = 1
diff --git a/src/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/vit.raft5.large.py b/src/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/vit.raft5.large.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e81c2ea343ea7bcd9980c76ee18e439d68c04e8
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/vit.raft5.large.py
@@ -0,0 +1,32 @@
+_base_=[
+ '../_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py',
+ '../_base_/datasets/_data_base_.py',
+ '../_base_/default_runtime.py',
+ ]
+
+model=dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+
+max_value = 200
+# configs of the canonical space
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.1, max_value),
+ crop_size = (616, 1064), # %28 = 0
+ clip_depth_range=(0.1, 200),
+ vit_size=(616,1064)
+)
+
+batchsize_per_gpu = 1
+thread_per_gpu = 1
diff --git a/src/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/vit.raft5.small.py b/src/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/vit.raft5.small.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0a169d49b38ab72b5f0c6b2eeb6e891431ad279
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/vit.raft5.small.py
@@ -0,0 +1,32 @@
+_base_=[
+ '../_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py',
+ '../_base_/datasets/_data_base_.py',
+ '../_base_/default_runtime.py',
+ ]
+
+model=dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=4,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+
+max_value = 200
+# configs of the canonical space
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.1, max_value),
+ crop_size = (616, 1064), # %28 = 0
+ clip_depth_range=(0.1, 200),
+ vit_size=(616,1064)
+)
+
+batchsize_per_gpu = 1
+thread_per_gpu = 1
diff --git a/src/custom_controlnet_aux/metric3d/mono/configs/__init__.py b/src/custom_controlnet_aux/metric3d/mono/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/configs/__init__.py
@@ -0,0 +1 @@
+
diff --git a/src/custom_controlnet_aux/metric3d/mono/configs/_base_/_data_base_.py b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/_data_base_.py
new file mode 100644
index 0000000000000000000000000000000000000000..35f3844f24191b6b9452e136ea3205b7622466d7
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/_data_base_.py
@@ -0,0 +1,13 @@
+# canonical camera setting and basic data setting
+# we set it same as the E300 camera (crop version)
+#
+data_basic=dict(
+ canonical_space = dict(
+ img_size=(540, 960),
+ focal_length=1196.0,
+ ),
+ depth_range=(0.9, 150),
+ depth_normalize=(0.006, 1.001),
+ crop_size = (512, 960),
+ clip_depth_range=(0.9, 150),
+)
diff --git a/src/custom_controlnet_aux/metric3d/mono/configs/_base_/datasets/_data_base_.py b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/datasets/_data_base_.py
new file mode 100644
index 0000000000000000000000000000000000000000..b554444e9b75b4519b862e726890dcf7859be0ec
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/datasets/_data_base_.py
@@ -0,0 +1,12 @@
+# canonical camera setting and basic data setting
+#
+data_basic=dict(
+ canonical_space = dict(
+ img_size=(540, 960),
+ focal_length=1196.0,
+ ),
+ depth_range=(0.9, 150),
+ depth_normalize=(0.006, 1.001),
+ crop_size = (512, 960),
+ clip_depth_range=(0.9, 150),
+)
diff --git a/src/custom_controlnet_aux/metric3d/mono/configs/_base_/default_runtime.py b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/default_runtime.py
new file mode 100644
index 0000000000000000000000000000000000000000..a690b491bf50aad5c2fd7e9ac387609123a4594a
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/default_runtime.py
@@ -0,0 +1,4 @@
+
+load_from = None
+cudnn_benchmark = True
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3','rmse_log', 'log10', 'sq_rel']
diff --git a/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/convnext_large.py b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/convnext_large.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a22f7e1b53ca154bfae1672e6ee3b52028039b9
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/convnext_large.py
@@ -0,0 +1,16 @@
+#_base_ = ['./_model_base_.py',]
+
+#'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-large_3rdparty_in21k_20220301-e6e0ea0a.pth'
+model = dict(
+ #type='EncoderDecoderAuxi',
+ backbone=dict(
+ type='convnext_large',
+ pretrained=True,
+ in_22k=True,
+ out_indices=[0, 1, 2, 3],
+ drop_path_rate=0.4,
+ layer_scale_init_value=1.0,
+ checkpoint='data/pretrained_weight_repo/convnext/convnext_large_22k_1k_384.pth',
+ prefix='backbones.',
+ out_channels=[192, 384, 768, 1536]),
+ )
diff --git a/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/dino_vit_giant2_reg.py b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/dino_vit_giant2_reg.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c1ebc96ceaa32ad9310d3b84d55d252be843c46
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/dino_vit_giant2_reg.py
@@ -0,0 +1,7 @@
+model = dict(
+ backbone=dict(
+ type='vit_giant2_reg',
+ prefix='backbones.',
+ out_channels=[1536, 1536, 1536, 1536],
+ drop_path_rate = 0.0),
+ )
diff --git a/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/dino_vit_large.py b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/dino_vit_large.py
new file mode 100644
index 0000000000000000000000000000000000000000..843178ed6e61d74070b971f01148f87fdf2a62cf
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/dino_vit_large.py
@@ -0,0 +1,7 @@
+model = dict(
+ backbone=dict(
+ type='vit_large',
+ prefix='backbones.',
+ out_channels=[1024, 1024, 1024, 1024],
+ drop_path_rate = 0.0),
+ )
diff --git a/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/dino_vit_large_reg.py b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/dino_vit_large_reg.py
new file mode 100644
index 0000000000000000000000000000000000000000..25e96747d459d42df299f8a6a1e14044a0e56164
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/dino_vit_large_reg.py
@@ -0,0 +1,7 @@
+model = dict(
+ backbone=dict(
+ type='vit_large_reg',
+ prefix='backbones.',
+ out_channels=[1024, 1024, 1024, 1024],
+ drop_path_rate = 0.0),
+ )
diff --git a/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/dino_vit_small_reg.py b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/dino_vit_small_reg.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c8bd97dccb9cdee7517250f40e01bb3124144e6
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/dino_vit_small_reg.py
@@ -0,0 +1,7 @@
+model = dict(
+ backbone=dict(
+ type='vit_small_reg',
+ prefix='backbones.',
+ out_channels=[384, 384, 384, 384],
+ drop_path_rate = 0.0),
+ )
diff --git a/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f262288c49e7ffccb6174b09b0daf80ff79dd684
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py
@@ -0,0 +1,10 @@
+# model settings
+_base_ = ['../backbones/convnext_large.py',]
+model = dict(
+ type='DensePredModel',
+ decode_head=dict(
+ type='HourglassDecoder',
+ in_channels=[192, 384, 768, 1536],
+ decoder_channel=[128, 128, 256, 512],
+ prefix='decode_heads.'),
+)
diff --git a/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/dino_vit_giant2_reg.dpt_raft.py b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/dino_vit_giant2_reg.dpt_raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..73702d298c05979bcdf013e9c30ec56f4e36665b
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/dino_vit_giant2_reg.dpt_raft.py
@@ -0,0 +1,19 @@
+# model settings
+_base_ = ['../backbones/dino_vit_giant2_reg.py']
+model = dict(
+ type='DensePredModel',
+ decode_head=dict(
+ type='RAFTDepthDPT',
+ in_channels=[1536, 1536, 1536, 1536],
+ use_cls_token=True,
+ feature_channels = [384, 768, 1536, 1536], # [2/7, 1/7, 1/14, 1/14]
+ decoder_channels = [192, 384, 768, 1536, 1536], # [4/7, 2/7, 1/7, 1/14, 1/14]
+ up_scale = 7,
+ hidden_channels=[192, 192, 192, 192], # [x_4, x_8, x_16, x_32] [192, 384, 768, 1536]
+ n_gru_layers=3,
+ n_downsample=2,
+ iters=3,
+ slow_fast_gru=True,
+ num_register_tokens=4,
+ prefix='decode_heads.'),
+)
diff --git a/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/dino_vit_large.dpt_raft.py b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/dino_vit_large.dpt_raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd69efefab2c03de435996c6b7b65ff941db1e5d
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/dino_vit_large.dpt_raft.py
@@ -0,0 +1,20 @@
+# model settings
+_base_ = ['../backbones/dino_vit_large.py']
+model = dict(
+ type='DensePredModel',
+ decode_head=dict(
+ type='RAFTDepthDPT',
+ in_channels=[1024, 1024, 1024, 1024],
+ use_cls_token=True,
+ feature_channels = [256, 512, 1024, 1024], # [2/7, 1/7, 1/14, 1/14]
+ decoder_channels = [128, 256, 512, 1024, 1024], # [4/7, 2/7, 1/7, 1/14, 1/14]
+ up_scale = 7,
+ hidden_channels=[128, 128, 128, 128], # [x_4, x_8, x_16, x_32] [192, 384, 768, 1536]
+ n_gru_layers=3,
+ n_downsample=2,
+ iters=12,
+ slow_fast_gru=True,
+ corr_radius=4,
+ corr_levels=4,
+ prefix='decode_heads.'),
+)
diff --git a/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..26ab6dc090e9cdb840d84fab10587becb536dbb8
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py
@@ -0,0 +1,19 @@
+# model settings
+_base_ = ['../backbones/dino_vit_large_reg.py']
+model = dict(
+ type='DensePredModel',
+ decode_head=dict(
+ type='RAFTDepthDPT',
+ in_channels=[1024, 1024, 1024, 1024],
+ use_cls_token=True,
+ feature_channels = [256, 512, 1024, 1024], # [2/7, 1/7, 1/14, 1/14]
+ decoder_channels = [128, 256, 512, 1024, 1024], # [4/7, 2/7, 1/7, 1/14, 1/14]
+ up_scale = 7,
+ hidden_channels=[128, 128, 128, 128], # [x_4, x_8, x_16, x_32] [192, 384, 768, 1536]
+ n_gru_layers=3,
+ n_downsample=2,
+ iters=3,
+ slow_fast_gru=True,
+ num_register_tokens=4,
+ prefix='decode_heads.'),
+)
diff --git a/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..19466c191e9f2a83903e55ca4fc0827d9a11bcb9
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py
@@ -0,0 +1,19 @@
+# model settings
+_base_ = ['../backbones/dino_vit_small_reg.py']
+model = dict(
+ type='DensePredModel',
+ decode_head=dict(
+ type='RAFTDepthDPT',
+ in_channels=[384, 384, 384, 384],
+ use_cls_token=True,
+ feature_channels = [96, 192, 384, 768], # [2/7, 1/7, 1/14, 1/14]
+ decoder_channels = [48, 96, 192, 384, 384], # [-, 1/4, 1/7, 1/14, 1/14]
+ up_scale = 7,
+ hidden_channels=[48, 48, 48, 48], # [x_4, x_8, x_16, x_32] [1/4, 1/7, 1/14, -]
+ n_gru_layers=3,
+ n_downsample=2,
+ iters=3,
+ slow_fast_gru=True,
+ num_register_tokens=4,
+ prefix='decode_heads.'),
+)
diff --git a/src/custom_controlnet_aux/metric3d/mono/model/__init__.py b/src/custom_controlnet_aux/metric3d/mono/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e1ea3d3e3b880e28ef880083b3c79e3b00cd119
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/model/__init__.py
@@ -0,0 +1,5 @@
+from .monodepth_model import DepthModel
+# from .__base_model__ import BaseDepthModel
+
+
+__all__ = ['DepthModel', 'BaseDepthModel']
diff --git a/src/custom_controlnet_aux/metric3d/mono/model/backbones/ConvNeXt.py b/src/custom_controlnet_aux/metric3d/mono/model/backbones/ConvNeXt.py
new file mode 100644
index 0000000000000000000000000000000000000000..6dada35e3f61ad98ffc069f25bcba0078358adfe
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/model/backbones/ConvNeXt.py
@@ -0,0 +1,260 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.layers import trunc_normal_, DropPath
+
+class Block(nn.Module):
+ r""" ConvNeXt Block. There are two equivalent implementations:
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+ We use (2) as we find it slightly faster in PyTorch
+
+ Args:
+ dim (int): Number of input channels.
+ drop_path (float): Stochastic depth rate. Default: 0.0
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+ """
+ def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
+ super().__init__()
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
+ self.norm = LayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(4 * dim, dim)
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
+ requires_grad=True) if layer_scale_init_value > 0 else None
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ def forward(self, x):
+ input = x
+ x = self.dwconv(x)
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ if self.gamma is not None:
+ x = self.gamma * x
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+
+ x = input + self.drop_path(x)
+ return x
+
+class ConvNeXt(nn.Module):
+ r""" ConvNeXt
+ A PyTorch impl of : `A ConvNet for the 2020s` -
+ https://arxiv.org/pdf/2201.03545.pdf
+ Args:
+ in_chans (int): Number of input image channels. Default: 3
+ num_classes (int): Number of classes for classification head. Default: 1000
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
+ """
+ def __init__(self, in_chans=3, num_classes=1000,
+ depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
+ layer_scale_init_value=1e-6, head_init_scale=1.,
+ **kwargs,):
+ super().__init__()
+
+ self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
+ stem = nn.Sequential(
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
+ )
+ self.downsample_layers.append(stem)
+ for i in range(3):
+ downsample_layer = nn.Sequential(
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+ nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
+ )
+ self.downsample_layers.append(downsample_layer)
+
+ self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
+ dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+ cur = 0
+ for i in range(4):
+ stage = nn.Sequential(
+ *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
+ layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
+ )
+ self.stages.append(stage)
+ cur += depths[i]
+
+ #self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
+ #self.head = nn.Linear(dims[-1], num_classes)
+
+ self.apply(self._init_weights)
+ #self.head.weight.data.mul_(head_init_scale)
+ #self.head.bias.data.mul_(head_init_scale)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
+ trunc_normal_(m.weight, std=.02)
+ nn.init.constant_(m.bias, 0)
+
+ def forward_features(self, x):
+ features = []
+ for i in range(4):
+ x = self.downsample_layers[i](x)
+ x = self.stages[i](x)
+ features.append(x)
+ return features # global average pooling, (N, C, H, W) -> (N, C)
+
+ def forward(self, x):
+ #x = self.forward_features(x)
+ #x = self.head(x)
+ features = self.forward_features(x)
+ return features
+
+class LayerNorm(nn.Module):
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+ with shape (batch_size, channels, height, width).
+ """
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
+ self.eps = eps
+ self.data_format = data_format
+ if self.data_format not in ["channels_last", "channels_first"]:
+ raise NotImplementedError
+ self.normalized_shape = (normalized_shape, )
+
+ def forward(self, x):
+ if self.data_format == "channels_last":
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+ elif self.data_format == "channels_first":
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
+
+
+model_urls = {
+ "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
+ "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
+ "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
+ "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
+ "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
+ "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
+ "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
+ "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
+ "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
+}
+
+def convnext_tiny(pretrained=True,in_22k=False, **kwargs):
+ model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
+ if pretrained:
+ checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu")
+ model_dict = model.state_dict()
+ pretrained_dict = {}
+ unmatched_pretrained_dict = {}
+ for k, v in checkpoint['model'].items():
+ if k in model_dict:
+ pretrained_dict[k] = v
+ else:
+ unmatched_pretrained_dict[k] = v
+ model_dict.update(pretrained_dict)
+ model.load_state_dict(model_dict)
+ print(
+ 'Successfully loaded pretrained %d params, and %d paras are unmatched.'
+ %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys())))
+ print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys())
+ return model
+
+def convnext_small(pretrained=True,in_22k=False, **kwargs):
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
+ if pretrained:
+ checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu")
+ model_dict = model.state_dict()
+ pretrained_dict = {}
+ unmatched_pretrained_dict = {}
+ for k, v in checkpoint['model'].items():
+ if k in model_dict:
+ pretrained_dict[k] = v
+ else:
+ unmatched_pretrained_dict[k] = v
+ model_dict.update(pretrained_dict)
+ model.load_state_dict(model_dict)
+ print(
+ 'Successfully loaded pretrained %d params, and %d paras are unmatched.'
+ %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys())))
+ print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys())
+ return model
+
+def convnext_base(pretrained=True, in_22k=False, **kwargs):
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
+ if pretrained:
+ checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu")
+ model_dict = model.state_dict()
+ pretrained_dict = {}
+ unmatched_pretrained_dict = {}
+ for k, v in checkpoint['model'].items():
+ if k in model_dict:
+ pretrained_dict[k] = v
+ else:
+ unmatched_pretrained_dict[k] = v
+ model_dict.update(pretrained_dict)
+ model.load_state_dict(model_dict)
+ print(
+ 'Successfully loaded pretrained %d params, and %d paras are unmatched.'
+ %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys())))
+ print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys())
+ return model
+
+def convnext_large(pretrained=True, in_22k=False, **kwargs):
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
+ if pretrained:
+ checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu")
+ model_dict = model.state_dict()
+ pretrained_dict = {}
+ unmatched_pretrained_dict = {}
+ for k, v in checkpoint['model'].items():
+ if k in model_dict:
+ pretrained_dict[k] = v
+ else:
+ unmatched_pretrained_dict[k] = v
+ model_dict.update(pretrained_dict)
+ model.load_state_dict(model_dict)
+ print(
+ 'Successfully loaded pretrained %d params, and %d paras are unmatched.'
+ %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys())))
+ print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys())
+ return model
+
+def convnext_xlarge(pretrained=True, in_22k=False, **kwargs):
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
+ if pretrained:
+ assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
+ checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu")
+ model_dict = model.state_dict()
+ pretrained_dict = {}
+ unmatched_pretrained_dict = {}
+ for k, v in checkpoint['model'].items():
+ if k in model_dict:
+ pretrained_dict[k] = v
+ else:
+ unmatched_pretrained_dict[k] = v
+ model_dict.update(pretrained_dict)
+ model.load_state_dict(model_dict)
+ print(
+ 'Successfully loaded pretrained %d params, and %d paras are unmatched.'
+ %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys())))
+ print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys())
+ return model
+
+if __name__ == '__main__':
+ import torch
+ model = convnext_base(True, in_22k=False).cuda()
+
+ rgb = torch.rand((2, 3, 256, 256)).cuda()
+ out = model(rgb)
+ print(len(out))
+ for i, ft in enumerate(out):
+ print(i, ft.shape)
diff --git a/src/custom_controlnet_aux/metric3d/mono/model/backbones/ViT_DINO.py b/src/custom_controlnet_aux/metric3d/mono/model/backbones/ViT_DINO.py
new file mode 100644
index 0000000000000000000000000000000000000000..b86c6b71e40817616a0b421b9ac9c9b80dde5f0b
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/model/backbones/ViT_DINO.py
@@ -0,0 +1,1489 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable, Optional, Dict, Any, List
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+
+#from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+logger = logging.getLogger("dinov2")
+
+class ConvBlock(nn.Module):
+ def __init__(self, channels):
+ super(ConvBlock, self).__init__()
+
+ self.act = nn.ReLU(inplace=True)
+ self.conv1 = nn.Conv2d(
+ channels,
+ channels,
+ kernel_size=3,
+ stride=1,
+ padding=1
+ )
+ self.norm1 = nn.BatchNorm2d(channels)
+ self.conv2 = nn.Conv2d(
+ channels,
+ channels,
+ kernel_size=3,
+ stride=1,
+ padding=1
+ )
+ self.norm2 = nn.BatchNorm2d(channels)
+
+ def forward(self, x):
+
+ out = self.norm1(x)
+ out = self.act(out)
+ out = self.conv1(out)
+ out = self.norm2(out)
+ out = self.act(out)
+ out = self.conv2(out)
+ return x + out
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+try:
+ from xformers.ops import SwiGLU
+ #import numpy.bool
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ SwiGLU = SwiGLUFFN
+ XFORMERS_AVAILABLE = False
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
+
+
+XFORMERS_AVAILABLE = False
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ window_size: int = 0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ #if not self.training:
+ #
+ # self.attn = ScaledDotProduct()
+ #self.attn = MultiHeadDispatch(dim_model=EMB, residual_dropout=DROPOUT, num_heads=HEADS, attention=attn)
+
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ attn = q @ k.transpose(-2, -1)
+
+ if attn_bias is not None:
+ attn = attn + attn_bias[:, :, :N]
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ #if True:
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
+ return super().forward(x, attn_bias)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+ if attn_bias is not None:
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias[:, :, :N])
+ else:
+ x = memory_efficient_attention(q, k, v)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+XFORMERS_AVAILABLE = False
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values = None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ def attn_residual_func(x: Tensor, attn_bias) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ attn_bias=attn_bias
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x, attn_bias))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x, attn_bias)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0, attn_bias=None
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset, attn_bias)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list, attn_bias=None):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list, attn_bias)
+ elif isinstance(x_or_x_list, list):
+ assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x, others=None):
+ for b in self:
+ if others == None:
+ x = b(x)
+ else:
+ x = b(x, others)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ #init_values=None, # for layerscale: None or 0 => no layerscale
+ init_values=1e-5, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=NestedTensorBlock,
+ ffn_layer="mlp",
+ block_chunks=1,
+ window_size=37,
+ **kwargs
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.window_size = window_size
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ w0, h0 = w0 + 0.1, h0 + 0.1
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+ mode="bicubic",
+ )
+
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+ for blk in self.blocks:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_patchtokens": x_norm[:, 1:],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ B, C, H, W = x.size()
+ pad_h = (self.patch_size - H % self.patch_size)
+ pad_w = (self.patch_size - W % self.patch_size)
+ if pad_h == self.patch_size:
+ pad_h = 0
+ if pad_w == self.patch_size:
+ pad_w = 0
+ #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2))
+ if pad_h + pad_w > 0:
+ x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear')
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ features = []
+ for blk in self.blocks:
+ x = blk(x)
+ # for idx in range(len(self.blocks[0])):
+ # x = self.blocks[0][idx](x)
+ # if (idx + 1) % (len(self.blocks[0]) // 4) == 0:
+ # features.append(x)
+
+ #return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)]
+
+ x_norm = self.norm(x)
+ # return {
+ # "x_norm_clstoken": x_norm[:, 0],
+ # "x_norm_patchtokens": x_norm[:, 1:],
+ # "x_prenorm": x,
+ # "masks": masks,
+ # }
+ features = []
+ features.append(x_norm)
+ features.append(x_norm)
+ features.append(x_norm)
+ features.append(x_norm)
+ return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)]
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1:] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=False, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ return ret
+ # if is_training:
+ # return ret
+ # else:
+ # return self.head(ret["x_norm_clstoken"])
+
+
+class PosConv(nn.Module):
+ # PEG from https://arxiv.org/abs/2102.10882
+ def __init__(self, in_chans, embed_dim=768, stride=1):
+ super(PosConv, self).__init__()
+ self.proj = nn.Sequential(
+ nn.Conv2d(in_chans, embed_dim, 37, stride, 18, bias=True, groups=embed_dim),
+ )
+ self.stride = stride
+
+ def forward(self, x, size):
+ B, N, C = x.shape
+ cnn_feat_token = x.transpose(1, 2).view(B, C, *size)
+ x = self.proj(cnn_feat_token)
+ if self.stride == 1:
+ x += cnn_feat_token
+ x = x.flatten(2).transpose(1, 2)
+ return x
+
+ #def no_weight_decay(self):
+ #return ['proj.%d.weight' % i for i in range(4)]
+
+class DinoWindowVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ #init_values=None, # for layerscale: None or 0 => no layerscale
+ init_values=1e-5, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=NestedTensorBlock,
+ ffn_layer="mlp",
+ block_chunks=1,
+ window_size=7,
+ **kwargs
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ #self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ #self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+
+ self.pos_conv = PosConv(self.embed_dim, self.embed_dim)
+
+ self.window_size = window_size
+ #self.conv_block = nn.ModuleList([ConvBlock(embed_dim) for i in range(4)])
+ #self.conv_block = nn.ModuleList([nn.Identity() for i in range(4)])
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.nh = -1
+ self.nw = -1
+ try:
+ H = cfg.data_basic['crop_size'][0]
+ W = cfg.data_basic['crop_size'][1]
+ pad_h = (self.patch_size - H % self.patch_size)
+ pad_w = (self.patch_size - W % self.patch_size)
+ if pad_h == self.patch_size:
+ pad_h = 0
+ if pad_w == self.patch_size:
+ pad_w = 0
+ self.nh = (H + pad_h) // self.patch_size
+ self.nw = (W + pad_w) // self.patch_size
+ self.prepare_attn_bias((self.nh, self.nw))
+ except:
+ pass
+ self.init_weights()
+
+ self.total_step = 10000 # For PE -> GPE transfer
+ self.start_step = 2000
+ self.current_step = 20000
+
+ def init_weights(self):
+ #trunc_normal_(self.pos_embed, std=0.02)
+ #nn.init.normal_(self.cls_token, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+ for i in range(4):
+ try:
+ nn.init.constant_(self.conv_block[i].conv2.weight, 0.0)
+ except:
+ pass
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ #npatch = x.shape[1] - 1
+ #N = self.pos_embed.shape[1] - 1
+ npatch = x.shape[1]
+ N = self.pos_embed.shape[1]
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ #class_pos_embed = pos_embed[:, 0]
+ #patch_pos_embed = pos_embed[:, 1:]
+ patch_pos_embed = pos_embed
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ w0, h0 = w0 + 0.1, h0 + 0.1
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+ mode="bicubic",
+ )
+
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return patch_pos_embed.to(previous_dtype)
+ #return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def window_partition(self, x: torch.Tensor, window_size: int, hw: Tuple[int, int], conv_feature=False) -> Tuple[torch.Tensor, Tuple[int, int]]:
+ """
+ Partition into non-overlapping windows with padding if needed.
+ Args:
+ x (tensor): input tokens with [B, H, W, C].
+ window_size (int): window size.
+
+ Returns:
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
+ (Hp, Wp): padded height and width before partition
+ """
+ if conv_feature == False:
+ B, N, C = x.shape
+ H, W = hw[0], hw[1]
+
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size * window_size, C)
+ else:
+ B, C, H, W = x.shape
+
+ x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
+
+ windows = x.permute(0, 2, 4, 3, 5, 1).contiguous().view(-1, window_size * window_size, C)
+
+ #y = torch.cat((x_cls, windows), dim=1)
+ return windows #, (Hp, Wp)
+
+
+ def window_unpartition(self,
+ windows: torch.Tensor, window_size: int, hw: Tuple[int, int], conv_feature=False
+ ) -> torch.Tensor:
+ """
+ Window unpartition into original sequences and removing padding.
+ Args:
+ windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+ window_size (int): window size.
+ pad_hw (Tuple): padded height and width (Hp, Wp).
+ hw (Tuple): original height and width (H, W) before padding.
+
+ Returns:
+ x: unpartitioned sequences with [B, H, W, C].
+ """
+ H, W = hw
+
+ B = windows.shape[0] // (H * W // window_size // window_size)
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+
+ if conv_feature == False:
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp * Wp, -1)
+ else:
+ C = windows.shape[-1]
+ x = x.permute(0, 5, 1, 3, 2, 4).contiguous().view(B, C, H, W)
+
+ # if Hp > H or Wp > W:
+ # x = x[:, :H, :W, :].contiguous()
+ return x
+
+ def prepare_tokens_with_masks(self, x, masks=None, step=-1):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ #x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ if step == -1:
+ step = self.current_step
+ else:
+ self.current_step = step
+
+ if step < self.start_step:
+ coef = 0.0
+ elif step < self.total_step:
+ coef = (step - self.start_step) / (self.total_step - self.start_step)
+ else:
+ coef = 1.0
+
+ x = x + (1 - coef) * self.interpolate_pos_encoding(x, w, h) + coef * self.pos_conv(x, (self.nh, self.nw))
+
+ return x
+
+ def prepare_attn_bias(self, shape):
+ window_size = self.window_size
+ if window_size <= 0:
+ return
+
+ import xformers.components.attention.attention_patterns as AP
+
+ nh, nw = shape
+ radius = (window_size-1)//2
+ mask_ori = AP.local_2d_pattern(nh, nw, distance = radius + 0.1, p=torch.inf).cuda()
+
+ pad = (8 - (nh * nw) % 8)
+ if pad == 8:
+ pad = 0
+ mask_pad = nn.functional.pad(mask_ori, (0, pad)).contiguous()
+ if pad > 0:
+ mask = mask_pad[:, :-pad].view(nh, nw, nh, nw)
+ else:
+ mask = mask_pad[:, :].view(nh, nw, nh, nw)
+
+ # angle
+ mask[:radius+1, :radius+1, :window_size, :window_size] = True
+ mask[:radius+1, -radius-1:, :window_size, -window_size:] = True
+ mask[-radius-1:, :radius+1, -window_size:, :window_size] = True
+ mask[-radius-1:, -radius-1:, -window_size:, -window_size:] = True
+
+ # edge
+ mask[radius+1:-radius-1, :radius+1, :, :] = mask[radius+1:-radius-1, radius:radius+1, :, :]
+ mask[radius+1:-radius-1, -radius-1:, :, :] = mask[radius+1:-radius-1, -radius-1:-radius, :, :]
+ mask[:radius+1, radius+1:-radius-1, :, :] = mask[radius:radius+1, radius+1:-radius-1, :, :]
+ mask[-radius-1:, radius+1:-radius-1, :, :] = mask[-radius-1:-radius, radius+1:-radius-1, :, :]
+
+ mask = mask.view(nh*nw, nh*nw)
+ bias_pad = torch.log(mask_pad)
+ #bias = bias_pad[:, :-pad]
+ self.register_buffer('attn_bias', bias_pad)
+
+ return bias_pad
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+ for blk in self.blocks:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_patchtokens": x_norm[:, 1:],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None, **kwargs):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ B, C, H, W = x.size()
+ pad_h = (self.patch_size - H % self.patch_size)
+ pad_w = (self.patch_size - W % self.patch_size)
+ if pad_h == self.patch_size:
+ pad_h = 0
+ if pad_w == self.patch_size:
+ pad_w = 0
+ #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2))
+ if pad_h + pad_w > 0:
+ x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear')
+
+ nh = (H+pad_h)//self.patch_size
+ nw = (W+pad_w)//self.patch_size
+
+ if self.window_size > 0:
+ if nh == self.nh and nw == self.nw:
+ attn_bias = self.attn_bias
+ else:
+ attn_bias = self.prepare_attn_bias(((H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size))
+ self.nh = nh
+ self.nw = nw
+ attn_bias = attn_bias.unsqueeze(0).repeat(B * self.num_heads, 1, 1)
+ else:
+ attn_bias = None
+
+ x = self.prepare_tokens_with_masks(x, masks)
+ #x = self.patch_embed(x)
+
+ features = []
+ #x = self.window_partition(x, self.window_size, (H // self.patch_size, W // self.patch_size))
+ for blk in self.blocks:
+ x = blk(x, attn_bias)
+ #x = self.window_unpartition(x, self.window_size, (H // self.patch_size, W // self.patch_size))
+
+ # for idx in range(len(self.blocks[0])):
+ # x = self.blocks[0][idx](x, attn_bias)
+
+ # if (idx + 1) % (len(self.blocks[0]) // 4) == 0:
+ # x = self.window_unpartition(x, self.window_size, (H // self.patch_size, W // self.patch_size), conv_feature=True)
+ # x = self.conv_block[idx // (len(self.blocks[0]) // 4)](x)
+ # if idx + 1 != len(self.blocks[0]):
+ # x = self.window_partition(x, self.window_size, (H // self.patch_size, W // self.patch_size), conv_feature=True)
+ # else:
+ # b, c, h, w = x.size()
+ # x = x.permute(0, 2, 3, 1).contiguous().view(b, h, w, c)
+ #features.append(x)
+
+ #return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)]
+
+ x_norm = self.norm(x)
+ # return {
+ # "x_norm_clstoken": x_norm[:, 0],
+ # "x_norm_patchtokens": x_norm[:, 1:],
+ # "x_prenorm": x,
+ # "masks": masks,
+ # }
+ features = []
+ features.append(x_norm)
+ features.append(x_norm)
+ features.append(x_norm)
+ features.append(x_norm)
+ return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)]
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1:] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=False, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ return ret
+ # if is_training:
+ # return ret
+ # else:
+ # return self.head(ret["x_norm_clstoken"])
+
+
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=14, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=14, **kwargs):
+ model = DinoWindowVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=14, checkpoint=None, **kwargs):
+ model = DinoVisionTransformer(
+ img_size = 518,
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
+ **kwargs,
+ )
+
+ if checkpoint is not None:
+ with open(checkpoint, "rb") as f:
+ state_dict = torch.load(f)
+ try:
+ model.load_state_dict(state_dict, strict=True)
+ except:
+ new_state_dict = {}
+ for key, value in state_dict.items():
+ if 'blocks' in key:
+ key_new = 'blocks.0' + key[len('blocks'):]
+ else:
+ key_new = key
+ new_state_dict[key_new] = value
+
+ model.load_state_dict(new_state_dict, strict=True)
+ #del model.norm
+ del model.mask_token
+ return model
+
+ # model = DinoWindowVisionTransformer(
+ # img_size = 518,
+ # patch_size=patch_size,
+ # embed_dim=1024,
+ # depth=24,
+ # num_heads=16,
+ # mlp_ratio=4,
+ # block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
+ # window_size=37,
+ # **kwargs,
+ # )
+
+ # if checkpoint is not None:
+ # with open(checkpoint, "rb") as f:
+ # state_dict = torch.load(f)
+ # try:
+ # model.load_state_dict(state_dict, strict=True)
+ # except:
+ # new_state_dict = {}
+ # for key, value in state_dict.items():
+ # if 'blocks' in key:
+ # key_new = 'blocks.0' + key[len('blocks'):]
+ # else:
+ # key_new = key
+ # if 'pos_embed' in key:
+ # value = value[:, 1:, :]
+ # new_state_dict[key_new] = value
+
+ # model.load_state_dict(new_state_dict, strict=False)
+ # #del model.norm
+ # del model.mask_token
+ return model
+
+
+def vit_giant2(patch_size=16, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ **kwargs,
+ )
+ return model
+
+if __name__ == '__main__':
+ try:
+ from custom_mmpkg.custom_mmcv.utils import Config
+ except:
+ from mmengine import Config
+
+ #rgb = torch.rand((2, 3, 518, 518)).cuda()
+
+ #cfg.data_basic['crop_size']['0']
+ #cfg.data_basic['crop_size']['1']
+ cfg = Config.fromfile('mu.hu/monodepth/mono/configs/HourglassDecoder/pub12.convlarge.0.3_150.py')
+
+ #rgb = torch.arange(0, 2*3*1036*1036, 1).cuda().float().view(2, 3, 1036, 1036)
+ rgb = torch.zeros(1, 3, 1400, 1680).cuda()
+ model = vit_large(checkpoint="pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth", kwarg=cfg).cuda()
+
+ #import timm
+ #model2 = timm.models.vision_transformer.vit_large_patch14_dinov2().cuda()
+ #timm.models.load_checkpoint(model2, '/cpfs02/shared/public/yvan/pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth', filter_fn=timm.models.vision_transformer.checkpoint_filter_fn)
+
+ out1 = model(rgb)
+ #out2 = model2(rgb)
+ temp = 0
+
+
+
+# import time
+# window_size = 37
+# def prepare_window_masks(shape):
+# if window_size <= 0:
+# return None
+# import xformers.components.attention.attention_patterns as AP
+
+# B, nh, nw, _, _ = shape
+# radius = (window_size-1)//2
+# #time0 = time.time()
+# d = AP.local_nd_distance(nh, nw, distance = radius + 0.1, p=torch.inf).cuda()
+# #mask = AP.local_2d_pattern(nh, nw, distance = radius + 0.1, p=torch.inf).cuda()
+# # mask = mask.view(nh, nw, nh, nw)
+# # #time1 = time.time() - time0
+
+# # # angle
+# # mask[:radius+1, :radius+1, :window_size, :window_size] = True
+# # mask[:radius+1, -radius-1:, :window_size, -window_size:] = True
+# # mask[-radius-1:, :radius+1, -window_size:, :window_size] = True
+# # mask[-radius-1:, -radius-1:, -window_size:, -window_size:] = True
+# # time2 = time.time() - time0 - time1
+
+# # # edge
+# # mask[radius+1:-radius-1, :radius+1, :, :] = mask[radius+1:-radius-1, radius:radius+1, :, :]
+# # mask[radius+1:-radius-1, -radius-1:, :, :] = mask[radius+1:-radius-1, -radius-1:-radius, :, :]
+# # mask[:radius+1, radius+1:-radius-1, :, :] = mask[radius:radius+1, radius+1:-radius-1, :, :]
+# # mask[-radius-1:, radius+1:-radius-1, :, :] = mask[-radius-1:-radius, radius+1:-radius-1, :, :]
+# # time3 = time.time() - time0 - time2
+# # print(time1, time2, time3)
+
+# # return mask.view(nw*nw, nh*nw).unsqueeze(0).repeat(B, 1)
+
+# shape = (1, 55, 55, None, None)
+# mask = prepare_window_masks(shape)
+# # temp = 1
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/metric3d/mono/model/backbones/ViT_DINO_reg.py b/src/custom_controlnet_aux/metric3d/mono/model/backbones/ViT_DINO_reg.py
new file mode 100644
index 0000000000000000000000000000000000000000..346ed10c4689d5f321f8ada9f26d1929bb956182
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/model/backbones/ViT_DINO_reg.py
@@ -0,0 +1,1303 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable, Optional, Dict, Any, List
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+import torch.nn.init
+import torch.nn.functional as F
+
+#from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+logger = logging.getLogger("dinov2")
+
+# SSF finetuning originally by dongzelian
+def init_ssf_scale_shift(dim):
+ scale = nn.Parameter(torch.ones(dim))
+ shift = nn.Parameter(torch.zeros(dim))
+
+ nn.init.normal_(scale, mean=1, std=.02)
+ nn.init.normal_(shift, std=.02)
+
+ return scale, shift
+
+def ssf_ada(x, scale, shift):
+ assert scale.shape == shift.shape
+ if x.shape[-1] == scale.shape[0]:
+ return x * scale + shift
+ elif x.shape[1] == scale.shape[0]:
+ return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)
+ else:
+ raise ValueError('the input tensor shape does not match the shape of the scale factor.')
+
+# LoRA finetuning originally by edwardjhu
+class LoRALayer():
+ def __init__(
+ self,
+ r: int,
+ lora_alpha: int,
+ lora_dropout: float,
+ merge_weights: bool,
+ ):
+ self.r = r
+ self.lora_alpha = lora_alpha
+ # Optional dropout
+ if lora_dropout > 0.:
+ self.lora_dropout = nn.Dropout(p=lora_dropout)
+ else:
+ self.lora_dropout = lambda x: x
+ # Mark the weight as unmerged
+ self.merged = False
+ self.merge_weights = merge_weights
+
+class LoRALinear(nn.Linear, LoRALayer):
+ # LoRA implemented in a dense layer
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ r: int = 0,
+ lora_alpha: int = 1,
+ lora_dropout: float = 0.,
+ fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+ merge_weights: bool = True,
+ **kwargs
+ ):
+ nn.Linear.__init__(self, in_features, out_features, **kwargs)
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
+ merge_weights=merge_weights)
+
+ self.fan_in_fan_out = fan_in_fan_out
+ # Actual trainable parameters
+ if r > 0:
+ self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
+ self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
+ self.scaling = self.lora_alpha / self.r
+ # Freezing the pre-trained weight matrix
+ self.weight.requires_grad = False
+ self.reset_parameters()
+ if fan_in_fan_out:
+ self.weight.data = self.weight.data.transpose(0, 1)
+
+ def reset_parameters(self):
+ #nn.Linear.reset_parameters(self)
+ if hasattr(self, 'lora_A'):
+ # initialize B the same way as the default for nn.Linear and A to zero
+ # this is different than what is described in the paper but should not affect performance
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+ nn.init.zeros_(self.lora_B)
+
+ # def train(self, mode: bool = True):
+ # def T(w):
+ # return w.transpose(0, 1) if self.fan_in_fan_out else w
+ # nn.Linear.train(self, mode)
+ # if mode:
+ # if self.merge_weights and self.merged:
+ # # Make sure that the weights are not merged
+ # if self.r > 0:
+ # self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+ # self.merged = False
+ # else:
+ # if self.merge_weights and not self.merged:
+ # # Merge the weights and mark it
+ # if self.r > 0:
+ # self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+ # self.merged = True
+
+ def forward(self, x: torch.Tensor):
+ def T(w):
+ return w.transpose(0, 1) if self.fan_in_fan_out else w
+ if self.r > 0 and not self.merged:
+ result = F.linear(x, T(self.weight), bias=self.bias)
+ result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
+ return result
+ else:
+ return F.linear(x, T(self.weight), bias=self.bias)
+
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ tuning_mode: Optional[str] = None
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ if tuning_mode != None:
+ self.tuning_mode = tuning_mode
+ if tuning_mode == 'ssf':
+ self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim)
+ else:
+ pass
+ #raise NotImplementedError()
+ else:
+ self.tuning_mode = None
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if self.tuning_mode == 'ssf':
+ x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ tuning_mode: Optional[int] = None
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ if tuning_mode != None:
+ self.tuning_mode = tuning_mode
+ if tuning_mode == 'ssf':
+ self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(hidden_features)
+ self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features)
+ else:
+ pass
+ #raise NotImplementedError()
+ else:
+ self.tuning_mode = None
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ if self.tuning_mode == 'ssf':
+ x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)
+
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ if self.tuning_mode == 'ssf':
+ x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)
+
+ x = self.drop(x)
+ return x
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ tuning_mode: Optional[int] = None
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ if tuning_mode != None:
+ self.tuning_mode = tuning_mode
+ if tuning_mode == 'ssf':
+ self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(2 * hidden_features)
+ self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features)
+ else:
+ pass
+ #raise NotImplementedError()
+ else:
+ self.tuning_mode = None
+
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ if self.tuning_mode == 'ssf':
+ x12 = ssf_ada(x12, self.ssf_scale_1, self.ssf_shift_1)
+
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ out = self.w3(hidden)
+
+ if self.tuning_mode == 'ssf':
+ out = ssf_ada(out, self.ssf_scale_2, self.ssf_scale_2)
+
+ return out
+
+
+try:
+ from xformers.ops import SwiGLU
+ #import numpy.bool
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ SwiGLU = SwiGLUFFN
+ XFORMERS_AVAILABLE = False
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
+
+
+XFORMERS_AVAILABLE = False
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ window_size: int = 0,
+ tuning_mode: Optional[int] = None
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ if tuning_mode == 'lora':
+ self.tuning_mode = tuning_mode
+ self.qkv = LoRALinear(dim, dim * 3, bias=qkv_bias, r=8)
+ else:
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+
+ self.attn_drop = nn.Dropout(attn_drop)
+
+ if tuning_mode == 'lora':
+ self.tuning_mode = tuning_mode
+ self.proj = LoRALinear(dim, dim, bias=proj_bias, r=8)
+ else:
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ if tuning_mode != None:
+ self.tuning_mode = tuning_mode
+ if tuning_mode == 'ssf':
+ self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim * 3)
+ self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim)
+ else:
+ pass
+ #raise NotImplementedError()
+ else:
+ self.tuning_mode = None
+
+ #if not self.training:
+ #
+ # self.attn = ScaledDotProduct()
+ #self.attn = MultiHeadDispatch(dim_model=EMB, residual_dropout=DROPOUT, num_heads=HEADS, attention=attn)
+
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ B, N, C = x.shape
+ if self.tuning_mode == 'ssf':
+ qkv = ssf_ada(self.qkv(x), self.ssf_scale_1, self.ssf_shift_1).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ else:
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ attn = q @ k.transpose(-2, -1)
+
+ if attn_bias is not None:
+ attn = attn + attn_bias[:, :, :N]
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+
+ if self.tuning_mode == 'ssf':
+ x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)
+
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ #if True:
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
+ return super().forward(x, attn_bias)
+
+ B, N, C = x.shape
+ if self.tuning_mode == 'ssf':
+ qkv = ssf_ada(self.qkv(x), self.ssf_scale_1, self.ssf_shift_1).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+ else:
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+ if attn_bias is not None:
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias[:, :, :N])
+ else:
+ x = memory_efficient_attention(q, k, v)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ if self.tuning_mode == 'ssf':
+ x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)
+
+ x = self.proj_drop(x)
+ return x
+
+XFORMERS_AVAILABLE = False
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values = None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ tuning_mode: Optional[int] = None
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ tuning_mode=tuning_mode
+ )
+
+ if tuning_mode != None:
+ self.tuning_mode = tuning_mode
+ if tuning_mode == 'ssf':
+ self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim)
+ self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim)
+ else:
+ pass
+ #raise NotImplementedError()
+ else:
+ self.tuning_mode = None
+
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ def attn_residual_func(x: Tensor, attn_bias) -> Tensor:
+ if self.tuning_mode == 'ssf':
+ return self.ls1(self.attn(ssf_ada(self.norm1(x), self.ssf_scale_1, self.ssf_shift_1), attn_bias))
+ else:
+ return self.ls1(self.attn(self.norm1(x), attn_bias))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ if self.tuning_mode == 'ssf':
+ return self.ls2(self.mlp(ssf_ada(self.norm2(x), self.ssf_scale_2, self.ssf_shift_2)))
+ else:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ attn_bias=attn_bias
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x, attn_bias))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x, attn_bias)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0, attn_bias=None
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset, attn_bias)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list, attn_bias=None):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list, attn_bias)
+ elif isinstance(x_or_x_list, list):
+ assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x, others=None):
+ for b in self:
+ if others == None:
+ x = b(x)
+ else:
+ x = b(x, others)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=518,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=1e-5, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ ffn_layer="mlp",
+ block_chunks=1,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1,
+ multi_output=False,
+ tuning_mode=None,
+ **kwargs
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.num_register_tokens = num_register_tokens
+ self.interpolate_antialias = interpolate_antialias
+ self.interpolate_offset = interpolate_offset
+
+ if tuning_mode != None:
+ self.tuning_mode = tuning_mode
+ if tuning_mode == 'ssf':
+ self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim)
+ else:
+ pass
+ #raise NotImplementedError()
+ else:
+ self.tuning_mode = None
+ tuning_mode_list = [tuning_mode] * depth
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, tuning_mode=tuning_mode)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ self.multi_output = multi_output
+ assert num_register_tokens >= 0
+ self.register_tokens = (
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
+ )
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ tuning_mode=tuning_mode_list[i]
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ if self.register_tokens is not None:
+ nn.init.normal_(self.register_tokens, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
+
+ sqrt_N = math.sqrt(N)
+ sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
+ scale_factor=(sx, sy),
+ mode="bicubic",
+ antialias=self.interpolate_antialias,
+ )
+
+ assert int(w0) == patch_pos_embed.shape[-2]
+ assert int(h0) == patch_pos_embed.shape[-1]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ if self.register_tokens is not None:
+ x = torch.cat(
+ (
+ x[:, :1],
+ self.register_tokens.expand(x.shape[0], -1, -1),
+ x[:, 1:],
+ ),
+ dim=1,
+ )
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+ for blk in self.blocks:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ B, C, H, W = x.size()
+ pad_h = (self.patch_size - H % self.patch_size)
+ pad_w = (self.patch_size - W % self.patch_size)
+ if pad_h == self.patch_size:
+ pad_h = 0
+ if pad_w == self.patch_size:
+ pad_w = 0
+ #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2))
+ if pad_h + pad_w > 0:
+ x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear')
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ #for blk in self.blocks:
+ #x = blk(x)
+
+ #x_norm = self.norm(x)
+ #if self.tuning_mode == 'ssf':
+ #x_norm = ssf_ada(x_norm, self.ssf_scale_1, self.ssf_shift_1)
+
+ # return {
+ # "x_norm_clstoken": x_norm[:, 0],
+ # "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ # "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ # "x_prenorm": x,
+ # "masks": masks,
+ # }
+ # features = []
+ # features.append(x_norm)
+ # features.append(x_norm)
+ # features.append(x_norm)
+ # features.append(x_norm)
+ # return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W, self.num_register_tokens)]
+
+ if self.multi_output == False:
+ for blk in self.blocks:
+ x = blk(x)
+ x_norm = self.norm(x)
+ if self.tuning_mode == 'ssf':
+ x_norm = ssf_ada(x_norm, self.ssf_scale_1, self.ssf_shift_1)
+
+ features = []
+ features.append(x_norm)
+ features.append(x_norm)
+ features.append(x_norm)
+ features.append(x_norm)
+ return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W, self.num_register_tokens)]
+ else:
+ features = []
+ for blk in self.blocks:
+ for idx, sub_blk in enumerate(blk):
+ x = sub_blk(x)
+ if (idx + 1) % (len(blk) // 4) == 0:
+ features.append(x)
+
+ return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W, self.num_register_tokens)]
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1:] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=False, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ return ret
+ # if is_training:
+ # return ret
+ # else:
+ # return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+def load_ckpt_dino(checkpoint, model):
+ if checkpoint is not None:
+ try:
+ with open(checkpoint, "rb") as f:
+ state_dict = torch.load(f)
+ except:
+ print('NO pretrained imagenet ckpt available! Check your path!')
+ del model.mask_token
+ return
+
+ try:
+ model.load_state_dict(state_dict, strict=True)
+ except:
+ new_state_dict = {}
+ for key, value in state_dict.items():
+ if 'blocks' in key:
+ key_new = 'blocks.0' + key[len('blocks'):]
+ else:
+ key_new = key
+ new_state_dict[key_new] = value
+
+ model.load_state_dict(new_state_dict, strict=True)
+ del model.mask_token
+ return
+ else:
+ return
+
+
+def vit_small(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+
+ load_ckpt_dino(checkpoint, model)
+
+ return model
+
+
+def vit_base(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+
+ if checkpoint is not None:
+ with open(checkpoint, "rb") as f:
+ state_dict = torch.load(f)
+ try:
+ model.load_state_dict(state_dict, strict=True)
+ except:
+ new_state_dict = {}
+ for key, value in state_dict.items():
+ if 'blocks' in key:
+ key_new = 'blocks.0' + key[len('blocks'):]
+ else:
+ key_new = key
+ new_state_dict[key_new] = value
+
+ model.load_state_dict(new_state_dict, strict=True)
+ del model.mask_token
+ return model
+
+
+def vit_giant2(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ ffn_layer='swiglu',
+ **kwargs,
+ )
+ return model
+
+
+
+def vit_small_reg(patch_size=14, num_register_tokens=4, checkpoint=None, tuning_mode=None, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ tuning_mode=tuning_mode,
+ **kwargs,
+ )
+
+ load_ckpt_dino(checkpoint, model)
+
+ return model
+
+
+def vit_base_reg(patch_size=14, num_register_tokens=4, checkpoint=None, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+
+ load_ckpt_dino(checkpoint, model)
+
+ return model
+
+
+def vit_large_reg(patch_size=14, num_register_tokens=4, checkpoint=None, tuning_mode=None, **kwargs):
+ model = DinoVisionTransformer(
+ img_size = 518,
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ tuning_mode=tuning_mode,
+ **kwargs,
+ )
+
+ load_ckpt_dino(checkpoint, model)
+
+ return model
+
+
+def vit_giant2_reg(patch_size=14, num_register_tokens=4, checkpoint=None, tuning_mode=None, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ ffn_layer='swiglu',
+ tuning_mode=tuning_mode,
+ multi_output=True,
+ **kwargs,
+ )
+
+ load_ckpt_dino(checkpoint, model)
+
+ return model
+
+if __name__ == '__main__':
+ try:
+ from custom_mmpkg.custom_mmcv.utils import Config
+ except:
+ from mmengine import Config
+
+ #rgb = torch.rand((2, 3, 518, 518)).cuda()
+
+ #cfg.data_basic['crop_size']['0']
+ #cfg.data_basic['crop_size']['1']
+ cfg = Config.fromfile('/mu.hu/projects/monodepth_vit/mono/configs/RAFTDecoder/vit.raft5.large.kitti.py')
+
+ #rgb = torch.arange(0, 2*3*1036*1036, 1).cuda().float().view(2, 3, 1036, 1036)
+ rgb = torch.zeros(1, 3, 616, 1064).cuda()
+ cfg['tuning_mode'] = 'ssf'
+ #model = vit_large_reg(checkpoint="/cpfs02/shared/public/groups/local_map/yvan/pretrained_weight_repo/vit/dinov2_vitl14_reg4_pretrain.pth", kwarg=cfg).cuda()
+ model = vit_large_reg(tuning_mode='ssf').cuda()
+
+ #import timm
+ #model2 = timm.models.vision_transformer.vit_large_patch14_dinov2().cuda()
+ #timm.models.load_checkpoint(model2, '/cpfs02/shared/public/yvan/pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth', filter_fn=timm.models.vision_transformer.checkpoint_filter_fn)
+
+ out1 = model(rgb)
+ #out2 = model2(rgb)
+ temp = 0
+
+
diff --git a/src/custom_controlnet_aux/metric3d/mono/model/backbones/__init__.py b/src/custom_controlnet_aux/metric3d/mono/model/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..36204128dbb323426b0aa19e52674cbdc3a0f860
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/model/backbones/__init__.py
@@ -0,0 +1,11 @@
+from .ConvNeXt import convnext_xlarge
+from .ConvNeXt import convnext_small
+from .ConvNeXt import convnext_base
+from .ConvNeXt import convnext_large
+from .ConvNeXt import convnext_tiny
+from .ViT_DINO import vit_large
+from .ViT_DINO_reg import vit_small_reg, vit_large_reg, vit_giant2_reg
+
+__all__ = [
+ 'convnext_xlarge', 'convnext_small', 'convnext_base', 'convnext_large', 'convnext_tiny', 'vit_small_reg', 'vit_large_reg', 'vit_giant2_reg'
+]
diff --git a/src/custom_controlnet_aux/metric3d/mono/model/decode_heads/HourGlassDecoder.py b/src/custom_controlnet_aux/metric3d/mono/model/decode_heads/HourGlassDecoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e084382601e21e6ce5144abbd6a65f563905b659
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/model/decode_heads/HourGlassDecoder.py
@@ -0,0 +1,274 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import math
+import torch.nn.functional as F
+
+def compute_depth_expectation(prob, depth_values):
+ depth_values = depth_values.view(*depth_values.shape, 1, 1)
+ depth = torch.sum(prob * depth_values, 1)
+ return depth
+
+class ConvBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=3):
+ super(ConvBlock, self).__init__()
+
+ if kernel_size == 3:
+ self.conv = nn.Sequential(
+ nn.ReflectionPad2d(1),
+ nn.Conv2d(in_channels, out_channels, 3, padding=0, stride=1),
+ )
+ elif kernel_size == 1:
+ self.conv = nn.Conv2d(int(in_channels), int(out_channels), 1, padding=0, stride=1)
+
+ self.nonlin = nn.ELU(inplace=True)
+
+ def forward(self, x):
+ out = self.conv(x)
+ out = self.nonlin(out)
+ return out
+
+
+class ConvBlock_double(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=3):
+ super(ConvBlock_double, self).__init__()
+
+ if kernel_size == 3:
+ self.conv = nn.Sequential(
+ nn.ReflectionPad2d(1),
+ nn.Conv2d(in_channels, out_channels, 3, padding=0, stride=1),
+ )
+ elif kernel_size == 1:
+ self.conv = nn.Conv2d(int(in_channels), int(out_channels), 1, padding=0, stride=1)
+
+ self.nonlin = nn.ELU(inplace=True)
+ self.conv_2 = nn.Conv2d(out_channels, out_channels, 1, padding=0, stride=1)
+ self.nonlin_2 =nn.ELU(inplace=True)
+
+ def forward(self, x):
+ out = self.conv(x)
+ out = self.nonlin(out)
+ out = self.conv_2(out)
+ out = self.nonlin_2(out)
+ return out
+
+class DecoderFeature(nn.Module):
+ def __init__(self, feat_channels, num_ch_dec=[64, 64, 128, 256]):
+ super(DecoderFeature, self).__init__()
+ self.num_ch_dec = num_ch_dec
+ self.feat_channels = feat_channels
+
+ self.upconv_3_0 = ConvBlock(self.feat_channels[3], self.num_ch_dec[3], kernel_size=1)
+ self.upconv_3_1 = ConvBlock_double(
+ self.feat_channels[2] + self.num_ch_dec[3],
+ self.num_ch_dec[3],
+ kernel_size=1)
+
+ self.upconv_2_0 = ConvBlock(self.num_ch_dec[3], self.num_ch_dec[2], kernel_size=3)
+ self.upconv_2_1 = ConvBlock_double(
+ self.feat_channels[1] + self.num_ch_dec[2],
+ self.num_ch_dec[2],
+ kernel_size=3)
+
+ self.upconv_1_0 = ConvBlock(self.num_ch_dec[2], self.num_ch_dec[1], kernel_size=3)
+ self.upconv_1_1 = ConvBlock_double(
+ self.feat_channels[0] + self.num_ch_dec[1],
+ self.num_ch_dec[1],
+ kernel_size=3)
+ self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
+
+ def forward(self, ref_feature):
+ x = ref_feature[3]
+
+ x = self.upconv_3_0(x)
+ x = torch.cat((self.upsample(x), ref_feature[2]), 1)
+ x = self.upconv_3_1(x)
+
+ x = self.upconv_2_0(x)
+ x = torch.cat((self.upsample(x), ref_feature[1]), 1)
+ x = self.upconv_2_1(x)
+
+ x = self.upconv_1_0(x)
+ x = torch.cat((self.upsample(x), ref_feature[0]), 1)
+ x = self.upconv_1_1(x)
+ return x
+
+
+class UNet(nn.Module):
+ def __init__(self, inp_ch=32, output_chal=1, down_sample_times=3, channel_mode='v0'):
+ super(UNet, self).__init__()
+ basic_block = ConvBnReLU
+ num_depth = 128
+
+ self.conv0 = basic_block(inp_ch, num_depth)
+ if channel_mode == 'v0':
+ channels = [num_depth, num_depth//2, num_depth//4, num_depth//8, num_depth // 8]
+ elif channel_mode == 'v1':
+ channels = [num_depth, num_depth, num_depth, num_depth, num_depth, num_depth]
+ self.down_sample_times = down_sample_times
+ for i in range(down_sample_times):
+ setattr(
+ self, 'conv_%d' % i,
+ nn.Sequential(
+ basic_block(channels[i], channels[i+1], stride=2),
+ basic_block(channels[i+1], channels[i+1])
+ )
+ )
+ for i in range(down_sample_times-1,-1,-1):
+ setattr(self, 'deconv_%d' % i,
+ nn.Sequential(
+ nn.ConvTranspose2d(
+ channels[i+1],
+ channels[i],
+ kernel_size=3,
+ padding=1,
+ output_padding=1,
+ stride=2,
+ bias=False),
+ nn.BatchNorm2d(channels[i]),
+ nn.ReLU(inplace=True)
+ )
+ )
+ self.prob = nn.Conv2d(num_depth, output_chal, 1, stride=1, padding=0)
+
+ def forward(self, x):
+ features = {}
+ conv0 = self.conv0(x)
+ x = conv0
+ features[0] = conv0
+ for i in range(self.down_sample_times):
+ x = getattr(self, 'conv_%d' % i)(x)
+ features[i+1] = x
+ for i in range(self.down_sample_times-1,-1,-1):
+ x = features[i] + getattr(self, 'deconv_%d' % i)(x)
+ x = self.prob(x)
+ return x
+
+class ConvBnReLU(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
+ super(ConvBnReLU, self).__init__()
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=pad,
+ bias=False
+ )
+ self.bn = nn.BatchNorm2d(out_channels)
+
+ def forward(self, x):
+ return F.relu(self.bn(self.conv(x)), inplace=True)
+
+
+class HourglassDecoder(nn.Module):
+ def __init__(self, cfg):
+ super(HourglassDecoder, self).__init__()
+ self.inchannels = cfg.model.decode_head.in_channels # [256, 512, 1024, 2048]
+ self.decoder_channels = cfg.model.decode_head.decoder_channel # [64, 64, 128, 256]
+ self.min_val = cfg.data_basic.depth_normalize[0]
+ self.max_val = cfg.data_basic.depth_normalize[1]
+
+ self.num_ch_dec = self.decoder_channels # [64, 64, 128, 256]
+ self.num_depth_regressor_anchor = 512
+ self.feat_channels = self.inchannels
+ unet_in_channel = self.num_ch_dec[1]
+ unet_out_channel = 256
+
+ self.decoder_mono = DecoderFeature(self.feat_channels, self.num_ch_dec)
+ self.conv_out_2 = UNet(inp_ch=unet_in_channel,
+ output_chal=unet_out_channel + 1,
+ down_sample_times=3,
+ channel_mode='v0',
+ )
+
+ self.depth_regressor_2 = nn.Sequential(
+ nn.Conv2d(unet_out_channel,
+ self.num_depth_regressor_anchor,
+ kernel_size=3,
+ padding=1,
+ ),
+ nn.BatchNorm2d(self.num_depth_regressor_anchor),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(
+ self.num_depth_regressor_anchor,
+ self.num_depth_regressor_anchor,
+ kernel_size=1,
+ )
+ )
+ self.residual_channel = 16
+ self.conv_up_2 = nn.Sequential(
+ nn.Conv2d(1 + 2 + unet_out_channel, self.residual_channel, 3, padding=1),
+ nn.BatchNorm2d(self.residual_channel),
+ nn.ReLU(),
+ nn.Conv2d(self.residual_channel, self.residual_channel, 3, padding=1),
+ nn.Upsample(scale_factor=4),
+ nn.Conv2d(self.residual_channel, self.residual_channel, 3, padding=1),
+ nn.ReLU(),
+ nn.Conv2d(self.residual_channel, 1, 1, padding=0),
+ )
+
+ def get_bins(self, bins_num):
+ depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device='cuda')
+ depth_bins_vec = torch.exp(depth_bins_vec)
+ return depth_bins_vec
+
+ def register_depth_expectation_anchor(self, bins_num, B):
+ depth_bins_vec = self.get_bins(bins_num)
+ depth_bins_vec = depth_bins_vec.unsqueeze(0).repeat(B, 1)
+ self.register_buffer('depth_expectation_anchor', depth_bins_vec, persistent=False)
+
+ def upsample(self, x, scale_factor=2):
+ return F.interpolate(x, scale_factor=scale_factor, mode='nearest')
+
+ def regress_depth_2(self, feature_map_d):
+ prob = self.depth_regressor_2(feature_map_d).softmax(dim=1)
+ B = prob.shape[0]
+ if "depth_expectation_anchor" not in self._buffers:
+ self.register_depth_expectation_anchor(self.num_depth_regressor_anchor, B)
+ d = compute_depth_expectation(
+ prob,
+ self.depth_expectation_anchor[:B, ...]
+ ).unsqueeze(1)
+ return d
+
+ def create_mesh_grid(self, height, width, batch, device="cuda", set_buffer=True):
+ y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device),
+ torch.arange(0, width, dtype=torch.float32, device=device)], indexing='ij')
+ meshgrid = torch.stack((x, y))
+ meshgrid = meshgrid.unsqueeze(0).repeat(batch, 1, 1, 1)
+ return meshgrid
+
+ def forward(self, features_mono, **kwargs):
+ '''
+ trans_ref2src: list of transformation matrix from the reference view to source view. [B, 4, 4]
+ inv_intrinsic_pool: list of inverse intrinsic matrix.
+ features_mono: features of reference and source views. [[ref_f1, ref_f2, ref_f3, ref_f4],[src1_f1, src1_f2, src1_f3, src1_f4], ...].
+ '''
+ outputs = {}
+ # get encoder feature of the reference view
+ ref_feat = features_mono
+
+ feature_map_mono = self.decoder_mono(ref_feat)
+ feature_map_mono_pred = self.conv_out_2(feature_map_mono)
+ confidence_map_2 = feature_map_mono_pred[:, -1:, :, :]
+ feature_map_d_2 = feature_map_mono_pred[:, :-1, :, :]
+
+ depth_pred_2 = self.regress_depth_2(feature_map_d_2)
+
+ B, _, H, W = depth_pred_2.shape
+
+ meshgrid = self.create_mesh_grid(H, W, B)
+
+ depth_pred_mono = self.upsample(depth_pred_2, scale_factor=4) + 1e-1 * \
+ self.conv_up_2(
+ torch.cat((depth_pred_2, meshgrid[:B, ...], feature_map_d_2), 1)
+ )
+ confidence_map_mono = self.upsample(confidence_map_2, scale_factor=4)
+
+ outputs=dict(
+ prediction=depth_pred_mono,
+ confidence=confidence_map_mono,
+ pred_logit=None,
+ )
+ return outputs
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/metric3d/mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py b/src/custom_controlnet_aux/metric3d/mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py
new file mode 100644
index 0000000000000000000000000000000000000000..7790d3a03a7573d255238aba399c14f74b50507f
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py
@@ -0,0 +1,1031 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import math
+import torch.nn.functional as F
+
+# LORA finetuning originally by edwardjhu
+class LoRALayer():
+ def __init__(
+ self,
+ r: int,
+ lora_alpha: int,
+ lora_dropout: float,
+ merge_weights: bool,
+ ):
+ self.r = r
+ self.lora_alpha = lora_alpha
+ # Optional dropout
+ if lora_dropout > 0.:
+ self.lora_dropout = nn.Dropout(p=lora_dropout)
+ else:
+ self.lora_dropout = lambda x: x
+ # Mark the weight as unmerged
+ self.merged = False
+ self.merge_weights = merge_weights
+
+class LoRALinear(nn.Linear, LoRALayer):
+ # LoRA implemented in a dense layer
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ r: int = 0,
+ lora_alpha: int = 1,
+ lora_dropout: float = 0.,
+ fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+ merge_weights: bool = True,
+ **kwargs
+ ):
+ nn.Linear.__init__(self, in_features, out_features, **kwargs)
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
+ merge_weights=merge_weights)
+
+ self.fan_in_fan_out = fan_in_fan_out
+ # Actual trainable parameters
+ if r > 0:
+ self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
+ self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
+ self.scaling = self.lora_alpha / self.r
+ # Freezing the pre-trained weight matrix
+ self.weight.requires_grad = False
+ self.reset_parameters()
+ if fan_in_fan_out:
+ self.weight.data = self.weight.data.transpose(0, 1)
+
+ def reset_parameters(self):
+ #nn.Linear.reset_parameters(self)
+ if hasattr(self, 'lora_A'):
+ # initialize B the same way as the default for nn.Linear and A to zero
+ # this is different than what is described in the paper but should not affect performance
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+ nn.init.zeros_(self.lora_B)
+
+ # def train(self, mode: bool = True):
+ # def T(w):
+ # return w.transpose(0, 1) if self.fan_in_fan_out else w
+ # nn.Linear.train(self, mode)
+ # if mode:
+ # if self.merge_weights and self.merged:
+ # # Make sure that the weights are not merged
+ # if self.r > 0:
+ # self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+ # self.merged = False
+ # else:
+ # if self.merge_weights and not self.merged:
+ # # Merge the weights and mark it
+ # if self.r > 0:
+ # self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+ # self.merged = True
+
+ def forward(self, x: torch.Tensor):
+ def T(w):
+ return w.transpose(0, 1) if self.fan_in_fan_out else w
+ if self.r > 0 and not self.merged:
+ result = F.linear(x, T(self.weight), bias=self.bias)
+ result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
+ return result
+ else:
+ return F.linear(x, T(self.weight), bias=self.bias)
+
+class ConvLoRA(nn.Conv2d, LoRALayer):
+ def __init__(self, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
+ #self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs)
+ nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
+ assert isinstance(kernel_size, int)
+
+ # Actual trainable parameters
+ if r > 0:
+ self.lora_A = nn.Parameter(
+ self.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
+ )
+ self.lora_B = nn.Parameter(
+ self.weight.new_zeros((out_channels//self.groups*kernel_size, r*kernel_size))
+ )
+ self.scaling = self.lora_alpha / self.r
+ # Freezing the pre-trained weight matrix
+ self.weight.requires_grad = False
+ self.reset_parameters()
+ self.merged = False
+
+ def reset_parameters(self):
+ #self.conv.reset_parameters()
+ if hasattr(self, 'lora_A'):
+ # initialize A the same way as the default for nn.Linear and B to zero
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+ nn.init.zeros_(self.lora_B)
+
+ # def train(self, mode=True):
+ # super(ConvLoRA, self).train(mode)
+ # if mode:
+ # if self.merge_weights and self.merged:
+ # if self.r > 0:
+ # # Make sure that the weights are not merged
+ # self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
+ # self.merged = False
+ # else:
+ # if self.merge_weights and not self.merged:
+ # if self.r > 0:
+ # # Merge the weights and mark it
+ # self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
+ # self.merged = True
+
+ def forward(self, x):
+ if self.r > 0 and not self.merged:
+ # return self.conv._conv_forward(
+ # x,
+ # self.conv.weight + (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling,
+ # self.conv.bias
+ # )
+ weight = self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
+ bias = self.bias
+
+ return F.conv2d(x, weight, bias=bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
+ else:
+ return F.conv2d(x, self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
+
+class ConvTransposeLoRA(nn.ConvTranspose2d, LoRALayer):
+ def __init__(self, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
+ #self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs)
+ nn.ConvTranspose2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
+ assert isinstance(kernel_size, int)
+
+ # Actual trainable parameters
+ if r > 0:
+ self.lora_A = nn.Parameter(
+ self.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
+ )
+ self.lora_B = nn.Parameter(
+ self.weight.new_zeros((out_channels//self.groups*kernel_size, r*kernel_size))
+ )
+ self.scaling = self.lora_alpha / self.r
+ # Freezing the pre-trained weight matrix
+ self.weight.requires_grad = False
+ self.reset_parameters()
+ self.merged = False
+
+ def reset_parameters(self):
+ #self.conv.reset_parameters()
+ if hasattr(self, 'lora_A'):
+ # initialize A the same way as the default for nn.Linear and B to zero
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+ nn.init.zeros_(self.lora_B)
+
+ # def train(self, mode=True):
+ # super(ConvTransposeLoRA, self).train(mode)
+ # if mode:
+ # if self.merge_weights and self.merged:
+ # if self.r > 0:
+ # # Make sure that the weights are not merged
+ # self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
+ # self.merged = False
+ # else:
+ # if self.merge_weights and not self.merged:
+ # if self.r > 0:
+ # # Merge the weights and mark it
+ # self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
+ # self.merged = True
+
+ def forward(self, x):
+ if self.r > 0 and not self.merged:
+ weight = self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
+ bias = self.bias
+ return F.conv_transpose2d(x, weight,
+ bias=bias, stride=self.stride, padding=self.padding, output_padding=self.output_padding,
+ groups=self.groups, dilation=self.dilation)
+ else:
+ return F.conv_transpose2d(x, self.weight,
+ bias=self.bias, stride=self.stride, padding=self.padding, output_padding=self.output_padding,
+ groups=self.groups, dilation=self.dilation)
+ #return self.conv(x)
+
+class Conv2dLoRA(ConvLoRA):
+ def __init__(self, *args, **kwargs):
+ super(Conv2dLoRA, self).__init__(*args, **kwargs)
+
+class ConvTranspose2dLoRA(ConvTransposeLoRA):
+ def __init__(self, *args, **kwargs):
+ super(ConvTranspose2dLoRA, self).__init__(*args, **kwargs)
+
+
+def compute_depth_expectation(prob, depth_values):
+ depth_values = depth_values.view(*depth_values.shape, 1, 1)
+ depth = torch.sum(prob * depth_values, 1)
+ return depth
+
+def interpolate_float32(x, size=None, scale_factor=None, mode='nearest', align_corners=None):
+ return F.interpolate(x.float(), size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners)
+
+# def upflow8(flow, mode='bilinear'):
+# new_size = (8 * flow.shape[2], 8 * flow.shape[3])
+# return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
+
+def upflow4(flow, mode='bilinear'):
+ new_size = (4 * flow.shape[2], 4 * flow.shape[3])
+ return F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
+
+def coords_grid(batch, ht, wd):
+ # coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
+ coords = (torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)))
+ coords = torch.stack(coords[::-1], dim=0).float()
+ return coords[None].repeat(batch, 1, 1, 1)
+
+def norm_normalize(norm_out):
+ min_kappa = 0.01
+ norm_x, norm_y, norm_z, kappa = torch.split(norm_out, 1, dim=1)
+ norm = torch.sqrt(norm_x ** 2.0 + norm_y ** 2.0 + norm_z ** 2.0) + 1e-10
+ kappa = F.elu(kappa) + 1.0 + min_kappa
+ final_out = torch.cat([norm_x / norm, norm_y / norm, norm_z / norm, kappa], dim=1)
+ return final_out
+
+# uncertainty-guided sampling (only used during training)
+@torch.no_grad()
+def sample_points(init_normal, gt_norm_mask, sampling_ratio, beta):
+ device = init_normal.device
+ B, _, H, W = init_normal.shape
+ N = int(sampling_ratio * H * W)
+ beta = beta
+
+ # uncertainty map
+ uncertainty_map = -1 * init_normal[:, -1, :, :] # B, H, W
+
+ # gt_invalid_mask (B, H, W)
+ if gt_norm_mask is not None:
+ gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest')
+ gt_invalid_mask = gt_invalid_mask[:, 0, :, :] < 0.5
+ uncertainty_map[gt_invalid_mask] = -1e4
+
+ # (B, H*W)
+ _, idx = uncertainty_map.view(B, -1).sort(1, descending=True)
+
+ # importance sampling
+ if int(beta * N) > 0:
+ importance = idx[:, :int(beta * N)] # B, beta*N
+
+ # remaining
+ remaining = idx[:, int(beta * N):] # B, H*W - beta*N
+
+ # coverage
+ num_coverage = N - int(beta * N)
+
+ if num_coverage <= 0:
+ samples = importance
+ else:
+ coverage_list = []
+ for i in range(B):
+ idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N"
+ coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N
+ coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N
+ samples = torch.cat((importance, coverage), dim=1) # B, N
+
+ else:
+ # remaining
+ remaining = idx[:, :] # B, H*W
+
+ # coverage
+ num_coverage = N
+
+ coverage_list = []
+ for i in range(B):
+ idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N"
+ coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N
+ coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N
+ samples = coverage
+
+ # point coordinates
+ rows_int = samples // W # 0 for first row, H-1 for last row
+ rows_float = rows_int / float(H-1) # 0 to 1.0
+ rows_float = (rows_float * 2.0) - 1.0 # -1.0 to 1.0
+
+ cols_int = samples % W # 0 for first column, W-1 for last column
+ cols_float = cols_int / float(W-1) # 0 to 1.0
+ cols_float = (cols_float * 2.0) - 1.0 # -1.0 to 1.0
+
+ point_coords = torch.zeros(B, 1, N, 2)
+ point_coords[:, 0, :, 0] = cols_float # x coord
+ point_coords[:, 0, :, 1] = rows_float # y coord
+ point_coords = point_coords.to(device)
+ return point_coords, rows_int, cols_int
+
+class FlowHead(nn.Module):
+ def __init__(self, input_dim=128, hidden_dim=256, output_dim_depth=2, output_dim_norm=4, tuning_mode=None):
+ super(FlowHead, self).__init__()
+ self.conv1d = Conv2dLoRA(input_dim, hidden_dim // 2, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)
+ self.conv2d = Conv2dLoRA(hidden_dim // 2, output_dim_depth, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)
+
+ self.conv1n = Conv2dLoRA(input_dim, hidden_dim // 2, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)
+ self.conv2n = Conv2dLoRA(hidden_dim // 2, output_dim_norm, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ depth = self.conv2d(self.relu(self.conv1d(x)))
+ normal = self.conv2n(self.relu(self.conv1n(x)))
+ return torch.cat((depth, normal), dim=1)
+
+
+class ConvGRU(nn.Module):
+ def __init__(self, hidden_dim, input_dim, kernel_size=3, tuning_mode=None):
+ super(ConvGRU, self).__init__()
+ self.convz = Conv2dLoRA(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2, r = 8 if tuning_mode == 'lora' else 0)
+ self.convr = Conv2dLoRA(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2, r = 8 if tuning_mode == 'lora' else 0)
+ self.convq = Conv2dLoRA(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2, r = 8 if tuning_mode == 'lora' else 0)
+
+ def forward(self, h, cz, cr, cq, *x_list):
+ x = torch.cat(x_list, dim=1)
+ hx = torch.cat([h, x], dim=1)
+
+ z = torch.sigmoid((self.convz(hx) + cz))
+ r = torch.sigmoid((self.convr(hx) + cr))
+ q = torch.tanh((self.convq(torch.cat([r*h, x], dim=1)) + cq))
+
+ # z = torch.sigmoid((self.convz(hx) + cz).float())
+ # r = torch.sigmoid((self.convr(hx) + cr).float())
+ # q = torch.tanh((self.convq(torch.cat([r*h, x], dim=1)) + cq).float())
+
+ h = (1-z) * h + z * q
+ return h
+
+def pool2x(x):
+ return F.avg_pool2d(x, 3, stride=2, padding=1)
+
+def pool4x(x):
+ return F.avg_pool2d(x, 5, stride=4, padding=1)
+
+def interp(x, dest):
+ interp_args = {'mode': 'bilinear', 'align_corners': True}
+ return interpolate_float32(x, dest.shape[2:], **interp_args)
+
+class BasicMultiUpdateBlock(nn.Module):
+ def __init__(self, args, hidden_dims=[], out_dims=2, tuning_mode=None):
+ super().__init__()
+ self.args = args
+ self.n_gru_layers = args.model.decode_head.n_gru_layers # 3
+ self.n_downsample = args.model.decode_head.n_downsample # 3, resolution of the disparity field (1/2^K)
+
+ # self.encoder = BasicMotionEncoder(args)
+ # encoder_output_dim = 128 # if there is corr volume
+ encoder_output_dim = 6 # no corr volume
+
+ self.gru08 = ConvGRU(hidden_dims[2], encoder_output_dim + hidden_dims[1] * (self.n_gru_layers > 1), tuning_mode=tuning_mode)
+ self.gru16 = ConvGRU(hidden_dims[1], hidden_dims[0] * (self.n_gru_layers == 3) + hidden_dims[2], tuning_mode=tuning_mode)
+ self.gru32 = ConvGRU(hidden_dims[0], hidden_dims[1], tuning_mode=tuning_mode)
+ self.flow_head = FlowHead(hidden_dims[2], hidden_dim=2*hidden_dims[2], tuning_mode=tuning_mode)
+ factor = 2**self.n_downsample
+
+ self.mask = nn.Sequential(
+ Conv2dLoRA(hidden_dims[2], hidden_dims[2], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0),
+ nn.ReLU(inplace=True),
+ Conv2dLoRA(hidden_dims[2], (factor**2)*9, 1, padding=0, r = 8 if tuning_mode == 'lora' else 0))
+
+ def forward(self, net, inp, corr=None, flow=None, iter08=True, iter16=True, iter32=True, update=True):
+
+ if iter32:
+ net[2] = self.gru32(net[2], *(inp[2]), pool2x(net[1]))
+ if iter16:
+ if self.n_gru_layers > 2:
+ net[1] = self.gru16(net[1], *(inp[1]), interp(pool2x(net[0]), net[1]), interp(net[2], net[1]))
+ else:
+ net[1] = self.gru16(net[1], *(inp[1]), interp(pool2x(net[0]), net[1]))
+ if iter08:
+ if corr is not None:
+ motion_features = self.encoder(flow, corr)
+ else:
+ motion_features = flow
+ if self.n_gru_layers > 1:
+ net[0] = self.gru08(net[0], *(inp[0]), motion_features, interp(net[1], net[0]))
+ else:
+ net[0] = self.gru08(net[0], *(inp[0]), motion_features)
+
+ if not update:
+ return net
+
+ delta_flow = self.flow_head(net[0])
+
+ # scale mask to balence gradients
+ mask = .25 * self.mask(net[0])
+ return net, mask, delta_flow
+
+class LayerNorm2d(nn.LayerNorm):
+ def __init__(self, dim):
+ super(LayerNorm2d, self).__init__(dim)
+
+ def forward(self, x):
+ x = x.permute(0, 2, 3, 1).contiguous()
+ x = super(LayerNorm2d, self).forward(x)
+ x = x.permute(0, 3, 1, 2).contiguous()
+ return x
+
+class ResidualBlock(nn.Module):
+ def __init__(self, in_planes, planes, norm_fn='group', stride=1, tuning_mode=None):
+ super(ResidualBlock, self).__init__()
+
+ self.conv1 = Conv2dLoRA(in_planes, planes, kernel_size=3, padding=1, stride=stride, r = 8 if tuning_mode == 'lora' else 0)
+ self.conv2 = Conv2dLoRA(planes, planes, kernel_size=3, padding=1, r = 8 if tuning_mode == 'lora' else 0)
+ self.relu = nn.ReLU(inplace=True)
+
+ num_groups = planes // 8
+
+ if norm_fn == 'group':
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ if not (stride == 1 and in_planes == planes):
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+ elif norm_fn == 'batch':
+ self.norm1 = nn.BatchNorm2d(planes)
+ self.norm2 = nn.BatchNorm2d(planes)
+ if not (stride == 1 and in_planes == planes):
+ self.norm3 = nn.BatchNorm2d(planes)
+
+ elif norm_fn == 'instance':
+ self.norm1 = nn.InstanceNorm2d(planes)
+ self.norm2 = nn.InstanceNorm2d(planes)
+ if not (stride == 1 and in_planes == planes):
+ self.norm3 = nn.InstanceNorm2d(planes)
+
+ elif norm_fn == 'layer':
+ self.norm1 = LayerNorm2d(planes)
+ self.norm2 = LayerNorm2d(planes)
+ if not (stride == 1 and in_planes == planes):
+ self.norm3 = LayerNorm2d(planes)
+
+ elif norm_fn == 'none':
+ self.norm1 = nn.Sequential()
+ self.norm2 = nn.Sequential()
+ if not (stride == 1 and in_planes == planes):
+ self.norm3 = nn.Sequential()
+
+ if stride == 1 and in_planes == planes:
+ self.downsample = None
+
+ else:
+ self.downsample = nn.Sequential(
+ Conv2dLoRA(in_planes, planes, kernel_size=1, stride=stride, r = 8 if tuning_mode == 'lora' else 0), self.norm3)
+
+ def forward(self, x):
+ y = x
+ y = self.conv1(y)
+ y = self.norm1(y)
+ y = self.relu(y)
+ y = self.conv2(y)
+ y = self.norm2(y)
+ y = self.relu(y)
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x+y)
+
+
+class ContextFeatureEncoder(nn.Module):
+ '''
+ Encoder features are used to:
+ 1. initialize the hidden state of the update operator
+ 2. and also injected into the GRU during each iteration of the update operator
+ '''
+ def __init__(self, in_dim, output_dim, tuning_mode=None):
+ '''
+ in_dim = [x4, x8, x16, x32]
+ output_dim = [hindden_dims, context_dims]
+ [[x4,x8,x16,x32],[x4,x8,x16,x32]]
+ '''
+ super().__init__()
+
+ output_list = []
+ for dim in output_dim:
+ conv_out = nn.Sequential(
+ ResidualBlock(in_dim[0], dim[0], 'layer', stride=1, tuning_mode=tuning_mode),
+ Conv2dLoRA(dim[0], dim[0], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0))
+ output_list.append(conv_out)
+
+ self.outputs04 = nn.ModuleList(output_list)
+
+ output_list = []
+ for dim in output_dim:
+ conv_out = nn.Sequential(
+ ResidualBlock(in_dim[1], dim[1], 'layer', stride=1, tuning_mode=tuning_mode),
+ Conv2dLoRA(dim[1], dim[1], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0))
+ output_list.append(conv_out)
+
+ self.outputs08 = nn.ModuleList(output_list)
+
+ output_list = []
+ for dim in output_dim:
+ conv_out = nn.Sequential(
+ ResidualBlock(in_dim[2], dim[2], 'layer', stride=1, tuning_mode=tuning_mode),
+ Conv2dLoRA(dim[2], dim[2], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0))
+ output_list.append(conv_out)
+
+ self.outputs16 = nn.ModuleList(output_list)
+
+ # output_list = []
+ # for dim in output_dim:
+ # conv_out = Conv2dLoRA(in_dim[3], dim[3], 3, padding=1)
+ # output_list.append(conv_out)
+
+ # self.outputs32 = nn.ModuleList(output_list)
+
+ def forward(self, encoder_features):
+ x_4, x_8, x_16, x_32 = encoder_features
+
+ outputs04 = [f(x_4) for f in self.outputs04]
+ outputs08 = [f(x_8) for f in self.outputs08]
+ outputs16 = [f(x_16)for f in self.outputs16]
+ # outputs32 = [f(x_32) for f in self.outputs32]
+
+ return (outputs04, outputs08, outputs16)
+
+class ConvBlock(nn.Module):
+ # reimplementation of DPT
+ def __init__(self, channels, tuning_mode=None):
+ super(ConvBlock, self).__init__()
+
+ self.act = nn.ReLU(inplace=True)
+ self.conv1 = Conv2dLoRA(
+ channels,
+ channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ r = 8 if tuning_mode == 'lora' else 0
+ )
+ self.conv2 = Conv2dLoRA(
+ channels,
+ channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ r = 8 if tuning_mode == 'lora' else 0
+ )
+
+ def forward(self, x):
+ out = self.act(x)
+ out = self.conv1(out)
+ out = self.act(out)
+ out = self.conv2(out)
+ return x + out
+
+class FuseBlock(nn.Module):
+ # reimplementation of DPT
+ def __init__(self, in_channels, out_channels, fuse=True, upsample=True, scale_factor=2, tuning_mode=None):
+ super(FuseBlock, self).__init__()
+
+ self.fuse = fuse
+ self.scale_factor = scale_factor
+ self.way_trunk = ConvBlock(in_channels, tuning_mode=tuning_mode)
+ if self.fuse:
+ self.way_branch = ConvBlock(in_channels, tuning_mode=tuning_mode)
+
+ self.out_conv = Conv2dLoRA(
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ r = 8 if tuning_mode == 'lora' else 0
+ )
+ self.upsample = upsample
+
+ def forward(self, x1, x2=None):
+ if x2 is not None:
+ x2 = self.way_branch(x2)
+ x1 = x1 + x2
+
+ out = self.way_trunk(x1)
+
+ if self.upsample:
+ out = interpolate_float32(
+ out, scale_factor=self.scale_factor, mode="bilinear", align_corners=True
+ )
+ out = self.out_conv(out)
+ return out
+
+class Readout(nn.Module):
+ # From DPT
+ def __init__(self, in_features, use_cls_token=True, num_register_tokens=0, tuning_mode=None):
+ super(Readout, self).__init__()
+ self.use_cls_token = use_cls_token
+ if self.use_cls_token == True:
+ self.project_patch = LoRALinear(in_features, in_features, r = 8 if tuning_mode == 'lora' else 0)
+ self.project_learn = LoRALinear((1 + num_register_tokens) * in_features, in_features, bias=False, r = 8 if tuning_mode == 'lora' else 0)
+ self.act = nn.GELU()
+ else:
+ self.project = nn.Identity()
+
+ def forward(self, x):
+
+ if self.use_cls_token == True:
+ x_patch = self.project_patch(x[0])
+ x_learn = self.project_learn(x[1])
+ x_learn = x_learn.expand_as(x_patch).contiguous()
+ features = x_patch + x_learn
+ return self.act(features)
+ else:
+ return self.project(x)
+
+class Token2Feature(nn.Module):
+ # From DPT
+ def __init__(self, vit_channel, feature_channel, scale_factor, use_cls_token=True, num_register_tokens=0, tuning_mode=None):
+ super(Token2Feature, self).__init__()
+ self.scale_factor = scale_factor
+ self.readoper = Readout(in_features=vit_channel, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode)
+ if scale_factor > 1 and isinstance(scale_factor, int):
+ self.sample = ConvTranspose2dLoRA(r = 8 if tuning_mode == 'lora' else 0,
+ in_channels=vit_channel,
+ out_channels=feature_channel,
+ kernel_size=scale_factor,
+ stride=scale_factor,
+ padding=0,
+ )
+
+ elif scale_factor > 1:
+ self.sample = nn.Sequential(
+ # Upsample2(upscale=scale_factor),
+ # nn.Upsample(scale_factor=scale_factor),
+ Conv2dLoRA(r = 8 if tuning_mode == 'lora' else 0,
+ in_channels=vit_channel,
+ out_channels=feature_channel,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ )
+
+
+ elif scale_factor < 1:
+ scale_factor = int(1.0 / scale_factor)
+ self.sample = Conv2dLoRA(r = 8 if tuning_mode == 'lora' else 0,
+ in_channels=vit_channel,
+ out_channels=feature_channel,
+ kernel_size=scale_factor+1,
+ stride=scale_factor,
+ padding=1,
+ )
+
+ else:
+ self.sample = nn.Identity()
+
+ def forward(self, x):
+ x = self.readoper(x)
+ #if use_cls_token == True:
+ x = x.permute(0, 3, 1, 2).contiguous()
+ if isinstance(self.scale_factor, float):
+ x = interpolate_float32(x.float(), scale_factor=self.scale_factor, mode='nearest')
+ x = self.sample(x)
+ return x
+
+class EncoderFeature(nn.Module):
+ def __init__(self, vit_channel, num_ch_dec=[256, 512, 1024, 1024], use_cls_token=True, num_register_tokens=0, tuning_mode=None):
+ super(EncoderFeature, self).__init__()
+ self.vit_channel = vit_channel
+ self.num_ch_dec = num_ch_dec
+
+ self.read_3 = Token2Feature(self.vit_channel, self.num_ch_dec[3], scale_factor=1, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode)
+ self.read_2 = Token2Feature(self.vit_channel, self.num_ch_dec[2], scale_factor=1, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode)
+ self.read_1 = Token2Feature(self.vit_channel, self.num_ch_dec[1], scale_factor=2, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode)
+ self.read_0 = Token2Feature(self.vit_channel, self.num_ch_dec[0], scale_factor=7/2, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode)
+
+ def forward(self, ref_feature):
+ x = self.read_3(ref_feature[3]) # 1/14
+ x2 = self.read_2(ref_feature[2]) # 1/14
+ x1 = self.read_1(ref_feature[1]) # 1/7
+ x0 = self.read_0(ref_feature[0]) # 1/4
+
+ return x, x2, x1, x0
+
+class DecoderFeature(nn.Module):
+ def __init__(self, vit_channel, num_ch_dec=[128, 256, 512, 1024, 1024], use_cls_token=True, tuning_mode=None):
+ super(DecoderFeature, self).__init__()
+ self.vit_channel = vit_channel
+ self.num_ch_dec = num_ch_dec
+
+ self.upconv_3 = FuseBlock(
+ self.num_ch_dec[4],
+ self.num_ch_dec[3],
+ fuse=False, upsample=False, tuning_mode=tuning_mode)
+
+ self.upconv_2 = FuseBlock(
+ self.num_ch_dec[3],
+ self.num_ch_dec[2],
+ tuning_mode=tuning_mode)
+
+ self.upconv_1 = FuseBlock(
+ self.num_ch_dec[2],
+ self.num_ch_dec[1] + 2,
+ scale_factor=7/4,
+ tuning_mode=tuning_mode)
+
+ # self.upconv_0 = FuseBlock(
+ # self.num_ch_dec[1],
+ # self.num_ch_dec[0] + 1,
+ # )
+
+ def forward(self, ref_feature):
+ x, x2, x1, x0 = ref_feature # 1/14 1/14 1/7 1/4
+
+ x = self.upconv_3(x) # 1/14
+ x = self.upconv_2(x, x2) # 1/7
+ x = self.upconv_1(x, x1) # 1/4
+ # x = self.upconv_0(x, x0) # 4/7
+ return x
+
+class RAFTDepthNormalDPT5(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.in_channels = cfg.model.decode_head.in_channels # [1024, 1024, 1024, 1024]
+ self.feature_channels = cfg.model.decode_head.feature_channels # [256, 512, 1024, 1024] [2/7, 1/7, 1/14, 1/14]
+ self.decoder_channels = cfg.model.decode_head.decoder_channels # [128, 256, 512, 1024, 1024] [-, 1/4, 1/7, 1/14, 1/14]
+ self.use_cls_token = cfg.model.decode_head.use_cls_token
+ self.up_scale = cfg.model.decode_head.up_scale
+ self.num_register_tokens = cfg.model.decode_head.num_register_tokens
+ self.min_val = cfg.data_basic.depth_normalize[0]
+ self.max_val = cfg.data_basic.depth_normalize[1]
+ self.regress_scale = 100.0\
+
+ try:
+ tuning_mode = cfg.model.decode_head.tuning_mode
+ except:
+ tuning_mode = None
+ self.tuning_mode = tuning_mode
+
+ self.hidden_dims = self.context_dims = cfg.model.decode_head.hidden_channels # [128, 128, 128, 128]
+ self.n_gru_layers = cfg.model.decode_head.n_gru_layers # 3
+ self.n_downsample = cfg.model.decode_head.n_downsample # 3, resolution of the disparity field (1/2^K)
+ self.iters = cfg.model.decode_head.iters # 22
+ self.slow_fast_gru = cfg.model.decode_head.slow_fast_gru # True
+
+ self.num_depth_regressor_anchor = 256 # 512
+ self.used_res_channel = self.decoder_channels[1] # now, use 2/7 res
+ self.token2feature = EncoderFeature(self.in_channels[0], self.feature_channels, self.use_cls_token, self.num_register_tokens, tuning_mode=tuning_mode)
+ self.decoder_mono = DecoderFeature(self.in_channels, self.decoder_channels, tuning_mode=tuning_mode)
+ self.depth_regressor = nn.Sequential(
+ Conv2dLoRA(self.used_res_channel,
+ self.num_depth_regressor_anchor,
+ kernel_size=3,
+ padding=1, r = 8 if tuning_mode == 'lora' else 0),
+ # nn.BatchNorm2d(self.num_depth_regressor_anchor),
+ nn.ReLU(inplace=True),
+ Conv2dLoRA(self.num_depth_regressor_anchor,
+ self.num_depth_regressor_anchor,
+ kernel_size=1, r = 8 if tuning_mode == 'lora' else 0),
+ )
+ self.normal_predictor = nn.Sequential(
+ Conv2dLoRA(self.used_res_channel,
+ 128,
+ kernel_size=3,
+ padding=1, r = 8 if tuning_mode == 'lora' else 0,),
+ # nn.BatchNorm2d(128),
+ nn.ReLU(inplace=True),
+ Conv2dLoRA(128, 128, kernel_size=1, r = 8 if tuning_mode == 'lora' else 0), nn.ReLU(inplace=True),
+ Conv2dLoRA(128, 128, kernel_size=1, r = 8 if tuning_mode == 'lora' else 0), nn.ReLU(inplace=True),
+ Conv2dLoRA(128, 3, kernel_size=1, r = 8 if tuning_mode == 'lora' else 0),
+ )
+
+ self.context_feature_encoder = ContextFeatureEncoder(self.feature_channels, [self.hidden_dims, self.context_dims], tuning_mode=tuning_mode)
+ self.context_zqr_convs = nn.ModuleList([Conv2dLoRA(self.context_dims[i], self.hidden_dims[i]*3, 3, padding=3//2, r = 8 if tuning_mode == 'lora' else 0) for i in range(self.n_gru_layers)])
+ self.update_block = BasicMultiUpdateBlock(cfg, hidden_dims=self.hidden_dims, out_dims=6, tuning_mode=tuning_mode)
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def get_bins(self, bins_num):
+ depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device=next(self.parameters()).device)
+ depth_bins_vec = torch.exp(depth_bins_vec)
+ return depth_bins_vec
+
+ def register_depth_expectation_anchor(self, bins_num, B):
+ depth_bins_vec = self.get_bins(bins_num)
+ depth_bins_vec = depth_bins_vec.unsqueeze(0).repeat(B, 1)
+ self.register_buffer('depth_expectation_anchor', depth_bins_vec, persistent=False)
+
+ def clamp(self, x):
+ y = self.relu(x - self.min_val) + self.min_val
+ y = self.max_val - self.relu(self.max_val - y)
+ return y
+
+ def regress_depth(self, feature_map_d):
+ prob_feature = self.depth_regressor(feature_map_d)
+ prob = prob_feature.softmax(dim=1)
+ #prob = prob_feature.float().softmax(dim=1)
+
+ ## Error logging
+ if torch.isnan(prob).any():
+ print('prob_feat_nan!!!')
+ if torch.isinf(prob).any():
+ print('prob_feat_inf!!!')
+
+ # h = prob[0,:,0,0].cpu().numpy().reshape(-1)
+ # import matplotlib.pyplot as plt
+ # plt.bar(range(len(h)), h)
+ B = prob.shape[0]
+ if "depth_expectation_anchor" not in self._buffers:
+ self.register_depth_expectation_anchor(self.num_depth_regressor_anchor, B)
+ d = compute_depth_expectation(
+ prob,
+ self.depth_expectation_anchor[:B, ...]).unsqueeze(1)
+
+ ## Error logging
+ if torch.isnan(d ).any():
+ print('d_nan!!!')
+ if torch.isinf(d ).any():
+ print('d_inf!!!')
+
+ return (self.clamp(d) - self.max_val)/ self.regress_scale, prob_feature
+
+ def pred_normal(self, feature_map, confidence):
+ normal_out = self.normal_predictor(feature_map)
+
+ ## Error logging
+ if torch.isnan(normal_out).any():
+ print('norm_nan!!!')
+ if torch.isinf(normal_out).any():
+ print('norm_feat_inf!!!')
+
+ return norm_normalize(torch.cat([normal_out, confidence], dim=1))
+ #return norm_normalize(torch.cat([normal_out, confidence], dim=1).float())
+
+ def create_mesh_grid(self, height, width, batch, device="cuda", set_buffer=True):
+ y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device),
+ torch.arange(0, width, dtype=torch.float32, device=device)], indexing='ij')
+ meshgrid = torch.stack((x, y))
+ meshgrid = meshgrid.unsqueeze(0).repeat(batch, 1, 1, 1)
+ #self.register_buffer('meshgrid', meshgrid, persistent=False)
+ return meshgrid
+
+ def upsample_flow(self, flow, mask):
+ """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
+ N, D, H, W = flow.shape
+ factor = 2 ** self.n_downsample
+ mask = mask.view(N, 1, 9, factor, factor, H, W)
+ mask = torch.softmax(mask, dim=2)
+ #mask = torch.softmax(mask.float(), dim=2)
+
+ #up_flow = F.unfold(factor * flow, [3,3], padding=1)
+ up_flow = F.unfold(flow, [3,3], padding=1)
+ up_flow = up_flow.view(N, D, 9, 1, 1, H, W)
+
+ up_flow = torch.sum(mask * up_flow, dim=2)
+ up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
+ return up_flow.reshape(N, D, factor*H, factor*W)
+
+ def initialize_flow(self, img):
+ """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
+ N, _, H, W = img.shape
+
+ coords0 = coords_grid(N, H, W).to(img.device)
+ coords1 = coords_grid(N, H, W).to(img.device)
+
+ return coords0, coords1
+
+ def upsample(self, x, scale_factor=2):
+ """Upsample input tensor by a factor of 2
+ """
+ return interpolate_float32(x, scale_factor=scale_factor*self.up_scale/8, mode="nearest")
+
+ def forward(self, vit_features, **kwargs):
+ ## read vit token to multi-scale features
+ B, H, W, _, _, num_register_tokens = vit_features[1]
+ vit_features = vit_features[0]
+
+ ## Error logging
+ if torch.isnan(vit_features[0]).any():
+ print('vit_feature_nan!!!')
+ if torch.isinf(vit_features[0]).any():
+ print('vit_feature_inf!!!')
+
+ if self.use_cls_token == True:
+ vit_features = [[ft[:, 1+num_register_tokens:, :].view(B, H, W, self.in_channels[0]), \
+ ft[:, 0:1+num_register_tokens, :].view(B, 1, 1, self.in_channels[0] * (1+num_register_tokens))] for ft in vit_features]
+ else:
+ vit_features = [ft.view(B, H, W, self.in_channels[0]) for ft in vit_features]
+ encoder_features = self.token2feature(vit_features) # 1/14, 1/14, 1/7, 1/4
+
+ ## Error logging
+ for en_ft in encoder_features:
+ if torch.isnan(en_ft).any():
+ print('decoder_feature_nan!!!')
+ print(en_ft.shape)
+ if torch.isinf(en_ft).any():
+ print('decoder_feature_inf!!!')
+ print(en_ft.shape)
+
+ ## decode features to init-depth (and confidence)
+ ref_feat= self.decoder_mono(encoder_features) # now, 1/4 for depth
+
+ ## Error logging
+ if torch.isnan(ref_feat).any():
+ print('ref_feat_nan!!!')
+ if torch.isinf(ref_feat).any():
+ print('ref_feat_inf!!!')
+
+ feature_map = ref_feat[:, :-2, :, :] # feature map share of depth and normal prediction
+ depth_confidence_map = ref_feat[:, -2:-1, :, :]
+ normal_confidence_map = ref_feat[:, -1:, :, :]
+ depth_pred, binmap = self.regress_depth(feature_map) # regress bin for depth
+ normal_pred = self.pred_normal(feature_map, normal_confidence_map) # mlp for normal
+
+ depth_init = torch.cat((depth_pred, depth_confidence_map, normal_pred), dim=1) # (N, 1+1+4, H, W)
+
+ ## encoder features to context-feature for init-hidden-state and contex-features
+ cnet_list = self.context_feature_encoder(encoder_features[::-1])
+ net_list = [torch.tanh(x[0]) for x in cnet_list] # x_4, x_8, x_16 of hidden state
+ inp_list = [torch.relu(x[1]) for x in cnet_list] # x_4, x_8, x_16 context features
+
+ # Rather than running the GRU's conv layers on the context features multiple times, we do it once at the beginning
+ inp_list = [list(conv(i).split(split_size=conv.out_channels//3, dim=1)) for i,conv in zip(inp_list, self.context_zqr_convs)]
+
+ coords0, coords1 = self.initialize_flow(net_list[0])
+ if depth_init is not None:
+ coords1 = coords1 + depth_init
+
+ if self.training:
+ low_resolution_init = [self.clamp(depth_init[:,:1] * self.regress_scale + self.max_val), depth_init[:,1:2], norm_normalize(depth_init[:,2:].clone())]
+ init_depth = upflow4(depth_init)
+ flow_predictions = [self.clamp(init_depth[:,:1] * self.regress_scale + self.max_val)]
+ conf_predictions = [init_depth[:,1:2]]
+ normal_outs = [norm_normalize(init_depth[:,2:].clone())]
+
+ else:
+ flow_predictions = []
+ conf_predictions = []
+ samples_pred_list = []
+ coord_list = []
+ normal_outs = []
+ low_resolution_init = []
+
+ for itr in range(self.iters):
+ # coords1 = coords1.detach()
+ flow = coords1 - coords0
+ if self.n_gru_layers == 3 and self.slow_fast_gru: # Update low-res GRU
+ net_list = self.update_block(net_list, inp_list, iter32=True, iter16=False, iter08=False, update=False)
+ if self.n_gru_layers >= 2 and self.slow_fast_gru:# Update low-res GRU and mid-res GRU
+ net_list = self.update_block(net_list, inp_list, iter32=self.n_gru_layers==3, iter16=True, iter08=False, update=False)
+ net_list, up_mask, delta_flow = self.update_block(net_list, inp_list, None, flow, iter32=self.n_gru_layers==3, iter16=self.n_gru_layers>=2)
+
+ # F(t+1) = F(t) + \Delta(t)
+ coords1 = coords1 + delta_flow
+
+ # We do not need to upsample or output intermediate results in test_mode
+ #if (not self.training) and itr < self.iters-1:
+ #continue
+
+ # upsample predictions
+ if up_mask is None:
+ flow_up = self.upsample(coords1-coords0, 4)
+ else:
+ flow_up = self.upsample_flow(coords1 - coords0, up_mask)
+ # flow_up = self.upsample(coords1-coords0, 4)
+
+ flow_predictions.append(self.clamp(flow_up[:,:1] * self.regress_scale + self.max_val))
+ conf_predictions.append(flow_up[:,1:2])
+ normal_outs.append(norm_normalize(flow_up[:,2:].clone()))
+
+ outputs=dict(
+ prediction=flow_predictions[-1],
+ predictions_list=flow_predictions,
+ confidence=conf_predictions[-1],
+ confidence_list=conf_predictions,
+ pred_logit=None,
+ # samples_pred_list=samples_pred_list,
+ # coord_list=coord_list,
+ prediction_normal=normal_outs[-1],
+ normal_out_list=normal_outs,
+ low_resolution_init=low_resolution_init,
+ )
+
+ return outputs
+
+
+if __name__ == "__main__":
+ try:
+ from custom_mmpkg.custom_mmcv.utils import Config
+ except:
+ from mmengine import Config
+ cfg = Config.fromfile('/cpfs01/shared/public/users/mu.hu/monodepth/mono/configs/RAFTDecoder/vit.raft.full2t.py')
+ cfg.model.decode_head.in_channels = [384, 384, 384, 384]
+ cfg.model.decode_head.feature_channels = [96, 192, 384, 768]
+ cfg.model.decode_head.decoder_channels = [48, 96, 192, 384, 384]
+ cfg.model.decode_head.hidden_channels = [48, 48, 48, 48, 48]
+ cfg.model.decode_head.up_scale = 7
+
+ # cfg.model.decode_head.use_cls_token = True
+ # vit_feature = [[torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \
+ # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \
+ # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \
+ # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()]]
+
+ cfg.model.decode_head.use_cls_token = True
+ cfg.model.decode_head.num_register_tokens = 4
+ vit_feature = [[torch.rand((2, (74 * 74) + 5, 384)).cuda(),\
+ torch.rand((2, (74 * 74) + 5, 384)).cuda(), \
+ torch.rand((2, (74 * 74) + 5, 384)).cuda(), \
+ torch.rand((2, (74 * 74) + 5, 384)).cuda()], (2, 74, 74, 1036, 1036, 4)]
+
+ decoder = RAFTDepthNormalDPT5(cfg).cuda()
+ output = decoder(vit_feature)
+ temp = 1
+
+
+
+
diff --git a/src/custom_controlnet_aux/metric3d/mono/model/decode_heads/__init__.py b/src/custom_controlnet_aux/metric3d/mono/model/decode_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..92381a5fc3dad0ca8009c1ab0a153ce6b107c634
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/model/decode_heads/__init__.py
@@ -0,0 +1,4 @@
+from .HourGlassDecoder import HourglassDecoder
+from .RAFTDepthNormalDPTDecoder5 import RAFTDepthNormalDPT5
+
+__all__=['HourglassDecoder', 'RAFTDepthNormalDPT5']
diff --git a/src/custom_controlnet_aux/metric3d/mono/model/model_pipelines/__base_model__.py b/src/custom_controlnet_aux/metric3d/mono/model/model_pipelines/__base_model__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ee81691ad00beb3187b2f131b08909caa5b1b16
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/model/model_pipelines/__base_model__.py
@@ -0,0 +1,25 @@
+import torch
+import torch.nn as nn
+from ...utils.comm import get_func
+
+
+class BaseDepthModel(nn.Module):
+ def __init__(self, cfg, **kwargs) -> None:
+ super(BaseDepthModel, self).__init__()
+ model_type = cfg.model.type
+ # Use relative import approach - get the module dynamically
+ from . import dense_pipeline
+ if model_type == 'DensePredModel':
+ self.depth_model = dense_pipeline.DensePredModel(cfg)
+ else:
+ raise NotImplementedError(f"Model type {model_type} not implemented")
+
+ def forward(self, data):
+ output = self.depth_model(**data)
+
+ return output['prediction'], output['confidence'], output
+
+ def inference(self, data):
+ with torch.no_grad():
+ pred_depth, confidence, _ = self.forward(data)
+ return pred_depth, confidence
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/metric3d/mono/model/model_pipelines/__init__.py b/src/custom_controlnet_aux/metric3d/mono/model/model_pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b962a3f858573466e429219c4ad70951b545b637
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/model/model_pipelines/__init__.py
@@ -0,0 +1,6 @@
+
+from .dense_pipeline import DensePredModel
+from .__base_model__ import BaseDepthModel
+__all__ = [
+ 'DensePredModel', 'BaseDepthModel',
+]
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/metric3d/mono/model/model_pipelines/dense_pipeline.py b/src/custom_controlnet_aux/metric3d/mono/model/model_pipelines/dense_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..509ec9b801edb146c8de10e26a24068101a8d247
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/model/model_pipelines/dense_pipeline.py
@@ -0,0 +1,44 @@
+import torch
+import torch.nn as nn
+from ...utils.comm import get_func
+
+class DensePredModel(nn.Module):
+ def __init__(self, cfg) -> None:
+ super(DensePredModel, self).__init__()
+
+ # Use direct imports instead of get_func to avoid module resolution issues
+
+ # Handle different backbone types
+ backbone_type = cfg.model.backbone.type
+ if backbone_type == 'vit_small_reg':
+ from ..backbones.ViT_DINO_reg import vit_small_reg
+ self.encoder = vit_small_reg(**cfg.model.backbone)
+ elif backbone_type == 'vit_large_reg':
+ from ..backbones.ViT_DINO_reg import vit_large_reg
+ self.encoder = vit_large_reg(**cfg.model.backbone)
+ elif backbone_type == 'vit_giant2_reg':
+ from ..backbones.ViT_DINO_reg import vit_giant2_reg
+ self.encoder = vit_giant2_reg(**cfg.model.backbone)
+ elif backbone_type == 'vit_large':
+ from ..backbones.ViT_DINO import vit_large
+ self.encoder = vit_large(**cfg.model.backbone)
+ elif backbone_type == 'convnext_large':
+ from ..backbones.ConvNeXt import convnext_large
+ self.encoder = convnext_large(**cfg.model.backbone)
+ else:
+ raise NotImplementedError(f"Backbone {backbone_type} not implemented")
+
+ # Handle decode head
+ decode_head_type = cfg.model.decode_head.type
+ if decode_head_type == 'RAFTDepthNormalDPT5':
+ from ..decode_heads.RAFTDepthNormalDPTDecoder5 import RAFTDepthNormalDPT5
+ self.decoder = RAFTDepthNormalDPT5(cfg)
+ else:
+ # For other decode head types, we'd need to check what file they're in
+ raise NotImplementedError(f"Decode head {decode_head_type} not implemented")
+
+ def forward(self, input, **kwargs):
+ # [f_32, f_16, f_8, f_4]
+ features = self.encoder(input)
+ out = self.decoder(features, **kwargs)
+ return out
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/metric3d/mono/model/monodepth_model.py b/src/custom_controlnet_aux/metric3d/mono/model/monodepth_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b58b7643ee43f84fd4e621e5b3b61b1f3f85564
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/model/monodepth_model.py
@@ -0,0 +1,37 @@
+import torch
+import torch.nn as nn
+from .model_pipelines.__base_model__ import BaseDepthModel
+
+class DepthModel(BaseDepthModel):
+ def __init__(self, cfg, **kwards):
+ super(DepthModel, self).__init__(cfg)
+ model_type = cfg.model.type
+
+ def inference(self, data):
+ with torch.no_grad():
+ pred_depth, confidence, output_dict = self.forward(data)
+ return pred_depth, confidence, output_dict
+
+def get_monodepth_model(
+ cfg : dict,
+ **kwargs
+ ) -> nn.Module:
+ # config depth model
+ model = DepthModel(cfg, **kwargs)
+ #model.init_weights(load_imagenet_model, imagenet_ckpt_fpath)
+ assert isinstance(model, nn.Module)
+ return model
+
+def get_configured_monodepth_model(
+ cfg: dict,
+ ) -> nn.Module:
+ """
+ Args:
+ @ configs: configures for the network.
+ @ load_imagenet_model: whether to initialize from ImageNet-pretrained model.
+ @ imagenet_ckpt_fpath: string representing path to file with weights to initialize model with.
+ Returns:
+ # model: depth model.
+ """
+ model = get_monodepth_model(cfg)
+ return model
diff --git a/src/custom_controlnet_aux/metric3d/mono/tools/test_scale_cano.py b/src/custom_controlnet_aux/metric3d/mono/tools/test_scale_cano.py
new file mode 100644
index 0000000000000000000000000000000000000000..9788d205ef61bf7fca7500ed03252de0b4d70728
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/tools/test_scale_cano.py
@@ -0,0 +1,161 @@
+import os
+import os.path as osp
+import cv2
+import time
+import sys
+CODE_SPACE=os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(CODE_SPACE)
+import argparse
+import custom_mmpkg.custom_mmcv as mmcv
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+try:
+ from custom_mmpkg.custom_mmcv.utils import Config, DictAction
+except:
+ from mmengine import Config, DictAction
+from datetime import timedelta
+import random
+import numpy as np
+from custom_controlnet_aux.metric3d.mono.utils.logger import setup_logger
+import glob
+from custom_controlnet_aux.metric3d.mono.utils.comm import init_env
+from custom_controlnet_aux.metric3d.mono.model.monodepth_model import get_configured_monodepth_model
+from custom_controlnet_aux.metric3d.mono.utils.running import load_ckpt
+from custom_controlnet_aux.metric3d.mono.utils.do_test import do_scalecano_test_with_custom_data
+from custom_controlnet_aux.metric3d.mono.utils.mldb import load_data_info, reset_ckpt_path
+from custom_controlnet_aux.metric3d.mono.utils.custom_data import load_from_annos, load_data
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Train a segmentor')
+ parser.add_argument('config', help='train config file path')
+ parser.add_argument('--show-dir', help='the dir to save logs and visualization results')
+ parser.add_argument('--load-from', help='the checkpoint file to load weights from')
+ parser.add_argument('--node_rank', type=int, default=0)
+ parser.add_argument('--nnodes', type=int, default=1, help='number of nodes')
+ parser.add_argument('--options', nargs='+', action=DictAction, help='custom options')
+ parser.add_argument('--launcher', choices=['None', 'pytorch', 'slurm', 'mpi', 'ror'], default='slurm', help='job launcher')
+ parser.add_argument('--test_data_path', default='None', type=str, help='the path of test data')
+ parser.add_argument('--batch_size', default=1, type=int, help='the batch size for inference')
+ args = parser.parse_args()
+ return args
+
+def main(args):
+ os.chdir(CODE_SPACE)
+ cfg = Config.fromfile(args.config)
+
+ if args.options is not None:
+ cfg.merge_from_dict(args.options)
+
+ # show_dir is determined in this priority: CLI > segment in file > filename
+ if args.show_dir is not None:
+ # update configs according to CLI args if args.show_dir is not None
+ cfg.show_dir = args.show_dir
+ else:
+ # use condig filename + timestamp as default show_dir if args.show_dir is None
+ cfg.show_dir = osp.join('./show_dirs',
+ osp.splitext(osp.basename(args.config))[0],
+ args.timestamp)
+
+ # ckpt path
+ if args.load_from is None:
+ raise RuntimeError('Please set model path!')
+ cfg.load_from = args.load_from
+ cfg.batch_size = args.batch_size
+
+ # load data info
+ data_info = {}
+ load_data_info('data_info', data_info=data_info)
+ cfg.mldb_info = data_info
+ # update check point info
+ reset_ckpt_path(cfg.model, data_info)
+
+ # create show dir
+ os.makedirs(osp.abspath(cfg.show_dir), exist_ok=True)
+
+ # init the logger before other steps
+ cfg.log_file = osp.join(cfg.show_dir, f'{args.timestamp}.log')
+ logger = setup_logger(cfg.log_file)
+
+ # log some basic info
+ logger.info(f'Config:\n{cfg.pretty_text}')
+
+ # init distributed env dirst, since logger depends on the dist info
+ if args.launcher == 'None':
+ cfg.distributed = False
+ else:
+ cfg.distributed = True
+ init_env(args.launcher, cfg)
+ logger.info(f'Distributed training: {cfg.distributed}')
+
+ # dump config
+ cfg.dump(osp.join(cfg.show_dir, osp.basename(args.config)))
+ test_data_path = args.test_data_path
+ if not os.path.isabs(test_data_path):
+ test_data_path = osp.join(CODE_SPACE, test_data_path)
+
+ if 'json' in test_data_path:
+ test_data = load_from_annos(test_data_path)
+ else:
+ test_data = load_data(args.test_data_path)
+
+ if not cfg.distributed:
+ main_worker(0, cfg, args.launcher, test_data)
+ else:
+ # distributed training
+ if args.launcher == 'ror':
+ local_rank = cfg.dist_params.local_rank
+ main_worker(local_rank, cfg, args.launcher, test_data)
+ else:
+ mp.spawn(main_worker, nprocs=cfg.dist_params.num_gpus_per_node, args=(cfg, args.launcher, test_data))
+
+def main_worker(local_rank: int, cfg: dict, launcher: str, test_data: list):
+ if cfg.distributed:
+ cfg.dist_params.global_rank = cfg.dist_params.node_rank * cfg.dist_params.num_gpus_per_node + local_rank
+ cfg.dist_params.local_rank = local_rank
+
+ if launcher == 'ror':
+ init_torch_process_group(use_hvd=False)
+ else:
+ torch.cuda.set_device(local_rank)
+ default_timeout = timedelta(minutes=30)
+ dist.init_process_group(
+ backend=cfg.dist_params.backend,
+ init_method=cfg.dist_params.dist_url,
+ world_size=cfg.dist_params.world_size,
+ rank=cfg.dist_params.global_rank,
+ timeout=default_timeout)
+
+ logger = setup_logger(cfg.log_file)
+ # build model
+ model = get_configured_monodepth_model(cfg, )
+
+ # config distributed training
+ if cfg.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model.cuda(),
+ device_ids=[local_rank],
+ output_device=local_rank,
+ find_unused_parameters=True)
+ else:
+ model = torch.nn.DataParallel(model).cuda()
+
+ # load ckpt
+ model, _, _, _ = load_ckpt(cfg.load_from, model, strict_match=False)
+ model.eval()
+
+ do_scalecano_test_with_custom_data(
+ model,
+ cfg,
+ test_data,
+ logger,
+ cfg.distributed,
+ local_rank,
+ cfg.batch_size,
+ )
+
+if __name__ == '__main__':
+ args = parse_args()
+ timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+ args.timestamp = timestamp
+ main(args)
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/metric3d/mono/utils/__init__.py b/src/custom_controlnet_aux/metric3d/mono/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/utils/__init__.py
@@ -0,0 +1 @@
+
diff --git a/src/custom_controlnet_aux/metric3d/mono/utils/avg_meter.py b/src/custom_controlnet_aux/metric3d/mono/utils/avg_meter.py
new file mode 100644
index 0000000000000000000000000000000000000000..37ed9fffa7aa7be7eea094280102168993912f44
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/utils/avg_meter.py
@@ -0,0 +1,475 @@
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+import matplotlib.pyplot as plt
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+ def __init__(self) -> None:
+ self.reset()
+
+ def reset(self) -> None:
+ self.val = np.longdouble(0.0)
+ self.avg = np.longdouble(0.0)
+ self.sum = np.longdouble(0.0)
+ self.count = np.longdouble(0.0)
+
+ def update(self, val, n: float = 1) -> None:
+ self.val = val
+ self.sum += val
+ self.count += n
+ self.avg = self.sum / (self.count + 1e-6)
+
+class MetricAverageMeter(AverageMeter):
+ """
+ An AverageMeter designed specifically for evaluating segmentation results.
+ """
+ def __init__(self, metrics: list) -> None:
+ """ Initialize object. """
+ # average meters for metrics
+ self.abs_rel = AverageMeter()
+ self.rmse = AverageMeter()
+ self.silog = AverageMeter()
+ self.delta1 = AverageMeter()
+ self.delta2 = AverageMeter()
+ self.delta3 = AverageMeter()
+
+ self.metrics = metrics
+
+ self.consistency = AverageMeter()
+ self.log10 = AverageMeter()
+ self.rmse_log = AverageMeter()
+ self.sq_rel = AverageMeter()
+
+ # normal
+ self.normal_mean = AverageMeter()
+ self.normal_rmse = AverageMeter()
+ self.normal_a1 = AverageMeter()
+ self.normal_a2 = AverageMeter()
+
+ self.normal_median = AverageMeter()
+ self.normal_a3 = AverageMeter()
+ self.normal_a4 = AverageMeter()
+ self.normal_a5 = AverageMeter()
+
+
+ def update_metrics_cpu(self,
+ pred: torch.Tensor,
+ target: torch.Tensor,
+ mask: torch.Tensor,):
+ """
+ Update metrics on cpu
+ """
+
+ assert pred.shape == target.shape
+
+ if len(pred.shape) == 3:
+ pred = pred[:, None, :, :]
+ target = target[:, None, :, :]
+ mask = mask[:, None, :, :]
+ elif len(pred.shape) == 2:
+ pred = pred[None, None, :, :]
+ target = target[None, None, :, :]
+ mask = mask[None, None, :, :]
+
+
+ # Absolute relative error
+ abs_rel_sum, valid_pics = get_absrel_err(pred, target, mask)
+ abs_rel_sum = abs_rel_sum.numpy()
+ valid_pics = valid_pics.numpy()
+ self.abs_rel.update(abs_rel_sum, valid_pics)
+
+ # squared relative error
+ sqrel_sum, _ = get_sqrel_err(pred, target, mask)
+ sqrel_sum = sqrel_sum.numpy()
+ self.sq_rel.update(sqrel_sum, valid_pics)
+
+ # root mean squared error
+ rmse_sum, _ = get_rmse_err(pred, target, mask)
+ rmse_sum = rmse_sum.numpy()
+ self.rmse.update(rmse_sum, valid_pics)
+
+ # log root mean squared error
+ log_rmse_sum, _ = get_rmse_log_err(pred, target, mask)
+ log_rmse_sum = log_rmse_sum.numpy()
+ self.rmse.update(log_rmse_sum, valid_pics)
+
+ # log10 error
+ log10_sum, _ = get_log10_err(pred, target, mask)
+ log10_sum = log10_sum.numpy()
+ self.rmse.update(log10_sum, valid_pics)
+
+ # scale-invariant root mean squared error in log space
+ silog_sum, _ = get_silog_err(pred, target, mask)
+ silog_sum = silog_sum.numpy()
+ self.silog.update(silog_sum, valid_pics)
+
+ # ratio error, delta1, ....
+ delta1_sum, delta2_sum, delta3_sum, _ = get_ratio_error(pred, target, mask)
+ delta1_sum = delta1_sum.numpy()
+ delta2_sum = delta2_sum.numpy()
+ delta3_sum = delta3_sum.numpy()
+
+ self.delta1.update(delta1_sum, valid_pics)
+ self.delta2.update(delta1_sum, valid_pics)
+ self.delta3.update(delta1_sum, valid_pics)
+
+
+ def update_metrics_gpu(
+ self,
+ pred: torch.Tensor,
+ target: torch.Tensor,
+ mask: torch.Tensor,
+ is_distributed: bool,
+ pred_next: torch.tensor = None,
+ pose_f1_to_f2: torch.tensor = None,
+ intrinsic: torch.tensor = None):
+ """
+ Update metric on GPU. It supports distributed processing. If multiple machines are employed, please
+ set 'is_distributed' as True.
+ """
+ assert pred.shape == target.shape
+
+ if len(pred.shape) == 3:
+ pred = pred[:, None, :, :]
+ target = target[:, None, :, :]
+ mask = mask[:, None, :, :]
+ elif len(pred.shape) == 2:
+ pred = pred[None, None, :, :]
+ target = target[None, None, :, :]
+ mask = mask[None, None, :, :]
+
+
+ # Absolute relative error
+ abs_rel_sum, valid_pics = get_absrel_err(pred, target, mask)
+ if is_distributed:
+ dist.all_reduce(abs_rel_sum), dist.all_reduce(valid_pics)
+ abs_rel_sum = abs_rel_sum.cpu().numpy()
+ valid_pics = int(valid_pics)
+ self.abs_rel.update(abs_rel_sum, valid_pics)
+
+ # root mean squared error
+ rmse_sum, _ = get_rmse_err(pred, target, mask)
+ if is_distributed:
+ dist.all_reduce(rmse_sum)
+ rmse_sum = rmse_sum.cpu().numpy()
+ self.rmse.update(rmse_sum, valid_pics)
+
+ # log root mean squared error
+ log_rmse_sum, _ = get_rmse_log_err(pred, target, mask)
+ if is_distributed:
+ dist.all_reduce(log_rmse_sum)
+ log_rmse_sum = log_rmse_sum.cpu().numpy()
+ self.rmse_log.update(log_rmse_sum, valid_pics)
+
+ # log10 error
+ log10_sum, _ = get_log10_err(pred, target, mask)
+ if is_distributed:
+ dist.all_reduce(log10_sum)
+ log10_sum = log10_sum.cpu().numpy()
+ self.log10.update(log10_sum, valid_pics)
+
+ # scale-invariant root mean squared error in log space
+ silog_sum, _ = get_silog_err(pred, target, mask)
+ if is_distributed:
+ dist.all_reduce(silog_sum)
+ silog_sum = silog_sum.cpu().numpy()
+ self.silog.update(silog_sum, valid_pics)
+
+ # ratio error, delta1, ....
+ delta1_sum, delta2_sum, delta3_sum, _ = get_ratio_err(pred, target, mask)
+ if is_distributed:
+ dist.all_reduce(delta1_sum), dist.all_reduce(delta2_sum), dist.all_reduce(delta3_sum)
+ delta1_sum = delta1_sum.cpu().numpy()
+ delta2_sum = delta2_sum.cpu().numpy()
+ delta3_sum = delta3_sum.cpu().numpy()
+
+ self.delta1.update(delta1_sum, valid_pics)
+ self.delta2.update(delta2_sum, valid_pics)
+ self.delta3.update(delta3_sum, valid_pics)
+
+ # video consistency error
+ # consistency_rel_sum, valid_warps = get_video_consistency_err(pred, pred_next, pose_f1_to_f2, intrinsic)
+ # if is_distributed:
+ # dist.all_reduce(consistency_rel_sum), dist.all_reduce(valid_warps)
+ # consistency_rel_sum = consistency_rel_sum.cpu().numpy()
+ # valid_warps = int(valid_warps)
+ # self.consistency.update(consistency_rel_sum, valid_warps)
+
+ ## for surface normal
+ def update_normal_metrics_gpu(
+ self,
+ pred: torch.Tensor, # (B, 3, H, W)
+ target: torch.Tensor, # (B, 3, H, W)
+ mask: torch.Tensor, # (B, 1, H, W)
+ is_distributed: bool,
+ ):
+ """
+ Update metric on GPU. It supports distributed processing. If multiple machines are employed, please
+ set 'is_distributed' as True.
+ """
+ assert pred.shape == target.shape
+
+ valid_pics = torch.sum(mask, dtype=torch.float32) + 1e-6
+
+ if valid_pics < 10:
+ return
+
+ mean_error = rmse_error = a1_error = a2_error = dist_node_cnt = valid_pics
+ normal_error = torch.cosine_similarity(pred, target, dim=1)
+ normal_error = torch.clamp(normal_error, min=-1.0, max=1.0)
+ angle_error = torch.acos(normal_error) * 180.0 / torch.pi
+ angle_error = angle_error[:, None, :, :]
+ angle_error = angle_error[mask]
+ # Calculation error
+ mean_error = angle_error.sum() / valid_pics
+ rmse_error = torch.sqrt( torch.sum(torch.square(angle_error)) / valid_pics )
+ median_error = angle_error.median()
+ a1_error = 100.0 * (torch.sum(angle_error < 5) / valid_pics)
+ a2_error = 100.0 * (torch.sum(angle_error < 7.5) / valid_pics)
+
+ a3_error = 100.0 * (torch.sum(angle_error < 11.25) / valid_pics)
+ a4_error = 100.0 * (torch.sum(angle_error < 22.5) / valid_pics)
+ a5_error = 100.0 * (torch.sum(angle_error < 30) / valid_pics)
+
+ # if valid_pics > 1e-5:
+ # If the current node gets data with valid normal
+ dist_node_cnt = (valid_pics - 1e-6) / valid_pics
+
+ if is_distributed:
+ dist.all_reduce(dist_node_cnt)
+ dist.all_reduce(mean_error)
+ dist.all_reduce(rmse_error)
+ dist.all_reduce(a1_error)
+ dist.all_reduce(a2_error)
+
+ dist.all_reduce(a3_error)
+ dist.all_reduce(a4_error)
+ dist.all_reduce(a5_error)
+
+ dist_node_cnt = dist_node_cnt.cpu().numpy()
+ self.normal_mean.update(mean_error.cpu().numpy(), dist_node_cnt)
+ self.normal_rmse.update(rmse_error.cpu().numpy(), dist_node_cnt)
+ self.normal_a1.update(a1_error.cpu().numpy(), dist_node_cnt)
+ self.normal_a2.update(a2_error.cpu().numpy(), dist_node_cnt)
+
+ self.normal_median.update(median_error.cpu().numpy(), dist_node_cnt)
+ self.normal_a3.update(a3_error.cpu().numpy(), dist_node_cnt)
+ self.normal_a4.update(a4_error.cpu().numpy(), dist_node_cnt)
+ self.normal_a5.update(a5_error.cpu().numpy(), dist_node_cnt)
+
+
+ def get_metrics(self,):
+ """
+ """
+ metrics_dict = {}
+ for metric in self.metrics:
+ metrics_dict[metric] = self.__getattribute__(metric).avg
+ return metrics_dict
+
+
+ def get_metrics(self,):
+ """
+ """
+ metrics_dict = {}
+ for metric in self.metrics:
+ metrics_dict[metric] = self.__getattribute__(metric).avg
+ return metrics_dict
+
+def get_absrel_err(pred: torch.tensor,
+ target: torch.tensor,
+ mask: torch.tensor,
+ ):
+ """
+ Computes absolute relative error.
+ Tasks preprocessed depths (no nans, infs and non-positive values).
+ pred, target, and mask should be in the shape of [b, c, h, w]
+ """
+
+ assert len(pred.shape) == 4, len(target.shape) == 4
+ b, c, h, w = pred.shape
+ mask = mask.to(torch.float)
+ t_m = target * mask
+ p_m = pred * mask
+
+ # Mean Absolute Relative Error
+ rel = torch.abs(t_m - p_m) / (t_m + 1e-10) # compute errors
+ abs_rel_sum = torch.sum(rel.reshape((b, c, -1)), dim=2) # [b, c]
+ num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c]
+ abs_err = abs_rel_sum / (num + 1e-10)
+ valid_pics = torch.sum(num > 0)
+ return torch.sum(abs_err), valid_pics
+
+def get_sqrel_err(pred: torch.tensor,
+ target: torch.tensor,
+ mask: torch.tensor,
+ ):
+ """
+ Computes squared relative error.
+ Tasks preprocessed depths (no nans, infs and non-positive values).
+ pred, target, and mask should be in the shape of [b, c, h, w]
+ """
+
+ assert len(pred.shape) == 4, len(target.shape) == 4
+ b, c, h, w = pred.shape
+ mask = mask.to(torch.float)
+ t_m = target * mask
+ p_m = pred * mask
+
+ # squared Relative Error
+ sq_rel = torch.abs(t_m - p_m) ** 2 / (t_m + 1e-10) # compute errors
+ sq_rel_sum = torch.sum(sq_rel.reshape((b, c, -1)), dim=2) # [b, c]
+ num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c]
+ sqrel_err = sq_rel_sum / (num + 1e-10)
+ valid_pics = torch.sum(num > 0)
+ return torch.sum(sqrel_err), valid_pics
+
+def get_log10_err(pred: torch.tensor,
+ target: torch.tensor,
+ mask: torch.tensor,
+ ):
+ """
+ Computes log10 error.
+ Tasks preprocessed depths (no nans, infs and non-positive values).
+ pred, target, and mask should be in the shape of [b, c, h, w]
+ """
+
+ assert len(pred.shape) == 4, len(target.shape) == 4
+ b, c, h, w = pred.shape
+ mask = mask.to(torch.float)
+ t_m = target * mask
+ p_m = pred * mask
+
+ diff_log = (torch.log10(p_m+1e-10) - torch.log10(t_m+1e-10)) * mask
+ log10_diff = torch.abs(diff_log)
+ log10_sum = torch.sum(log10_diff.reshape((b, c, -1)), dim=2) # [b, c]
+ num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c]
+ log10_err = log10_sum / (num + 1e-10)
+ valid_pics = torch.sum(num > 0)
+ return torch.sum(log10_err), valid_pics
+
+def get_rmse_err(pred: torch.tensor,
+ target: torch.tensor,
+ mask: torch.tensor,
+ ):
+ """
+ Computes rmse error.
+ Tasks preprocessed depths (no nans, infs and non-positive values).
+ pred, target, and mask should be in the shape of [b, c, h, w]
+ """
+
+ assert len(pred.shape) == 4, len(target.shape) == 4
+ b, c, h, w = pred.shape
+ mask = mask.to(torch.float)
+ t_m = target * mask
+ p_m = pred * mask
+
+ square = (t_m - p_m) ** 2
+ rmse_sum = torch.sum(square.reshape((b, c, -1)), dim=2) # [b, c]
+ num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c]
+ rmse = torch.sqrt(rmse_sum / (num + 1e-10))
+ valid_pics = torch.sum(num > 0)
+ return torch.sum(rmse), valid_pics
+
+def get_rmse_log_err(pred: torch.tensor,
+ target: torch.tensor,
+ mask: torch.tensor,
+ ):
+ """
+ Computes log rmse error.
+ Tasks preprocessed depths (no nans, infs and non-positive values).
+ pred, target, and mask should be in the shape of [b, c, h, w]
+ """
+
+ assert len(pred.shape) == 4, len(target.shape) == 4
+ b, c, h, w = pred.shape
+ mask = mask.to(torch.float)
+ t_m = target * mask
+ p_m = pred * mask
+
+ diff_log = (torch.log10(p_m+1e-10) - torch.log10(t_m+1e-10)) * mask
+ square = diff_log ** 2
+ rmse_log_sum = torch.sum(square.reshape((b, c, -1)), dim=2) # [b, c]
+ num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c]
+ rmse_log = torch.sqrt(rmse_log_sum / (num + 1e-10))
+ valid_pics = torch.sum(num > 0)
+ return torch.sum(rmse_log), valid_pics
+
+def get_silog_err(pred: torch.tensor,
+ target: torch.tensor,
+ mask: torch.tensor,
+ ):
+ """
+ Computes log rmse error.
+ Tasks preprocessed depths (no nans, infs and non-positive values).
+ pred, target, and mask should be in the shape of [b, c, h, w]
+ """
+
+ assert len(pred.shape) == 4, len(target.shape) == 4
+ b, c, h, w = pred.shape
+ mask = mask.to(torch.float)
+ t_m = target * mask
+ p_m = pred * mask
+
+ diff_log = (torch.log10(p_m+1e-10) - torch.log10(t_m+1e-10)) * mask
+ diff_log_sum = torch.sum(diff_log.reshape((b, c, -1)), dim=2) # [b, c]
+ diff_log_square = diff_log ** 2
+ diff_log_square_sum = torch.sum(diff_log_square.reshape((b, c, -1)), dim=2) # [b, c]
+ num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c]
+ silog = torch.sqrt(diff_log_square_sum / (num + 1e-10) - (diff_log_sum / (num + 1e-10)) ** 2)
+ valid_pics = torch.sum(num > 0)
+ return torch.sum(silog), valid_pics
+
+def get_ratio_err(pred: torch.tensor,
+ target: torch.tensor,
+ mask: torch.tensor,
+ ):
+ """
+ Computes the percentage of pixels for which the ratio of the two depth maps is less than a given threshold.
+ Tasks preprocessed depths (no nans, infs and non-positive values).
+ pred, target, and mask should be in the shape of [b, c, h, w]
+ """
+ assert len(pred.shape) == 4, len(target.shape) == 4
+ b, c, h, w = pred.shape
+ mask = mask.to(torch.float)
+ t_m = target * mask
+ p_m = pred
+
+ gt_pred = t_m / (p_m + 1e-10)
+ pred_gt = p_m / (t_m + 1e-10)
+ gt_pred = gt_pred.reshape((b, c, -1))
+ pred_gt = pred_gt.reshape((b, c, -1))
+ gt_pred_gt = torch.cat((gt_pred, pred_gt), axis=1)
+ ratio_max = torch.amax(gt_pred_gt, axis=1)
+
+ delta_1_sum = torch.sum((ratio_max < 1.25), dim=1) # [b, ]
+ delta_2_sum = torch.sum((ratio_max < 1.25 ** 2), dim=1) # [b, ]
+ delta_3_sum = torch.sum((ratio_max < 1.25 ** 3), dim=1) # [b, ]
+ num = torch.sum(mask.reshape((b, -1)), dim=1) # [b, ]
+
+ delta_1 = delta_1_sum / (num + 1e-10)
+ delta_2 = delta_2_sum / (num + 1e-10)
+ delta_3 = delta_3_sum / (num + 1e-10)
+ valid_pics = torch.sum(num > 0)
+
+ return torch.sum(delta_1), torch.sum(delta_2), torch.sum(delta_3), valid_pics
+
+
+if __name__ == '__main__':
+ cfg = ['abs_rel', 'delta1']
+ dam = MetricAverageMeter(cfg)
+
+ pred_depth = np.random.random([2, 480, 640])
+ gt_depth = np.random.random([2, 480, 640]) - 0.5
+ intrinsic = [[100, 100, 200, 200], [200, 200, 300, 300]]
+
+ pred = torch.from_numpy(pred_depth).cuda()
+ gt = torch.from_numpy(gt_depth).cuda()
+
+ mask = gt > 0
+ dam.update_metrics_gpu(pred, gt, mask, False)
+ eval_error = dam.get_metrics()
+ print(eval_error)
+
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/metric3d/mono/utils/comm.py b/src/custom_controlnet_aux/metric3d/mono/utils/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..99d68a3e95f8b7b99edc1309961b8e5ce40ea9cd
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/utils/comm.py
@@ -0,0 +1,322 @@
+import importlib
+import torch
+import torch.distributed as dist
+from .avg_meter import AverageMeter
+from collections import defaultdict, OrderedDict
+import os
+import socket
+from custom_mmpkg.custom_mmcv.utils import collect_env as collect_base_env
+try:
+ from custom_mmpkg.custom_mmcv.utils import get_git_hash
+except:
+ from mmengine.utils import get_git_hash
+#import mono.mmseg as mmseg
+# import mmseg
+import time
+import datetime
+import logging
+
+
+def main_process() -> bool:
+ return get_rank() == 0
+ #return not cfg.distributed or \
+ # (cfg.distributed and cfg.local_rank == 0)
+
+def get_world_size() -> int:
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size()
+
+def get_rank() -> int:
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ return dist.get_rank()
+
+def _find_free_port():
+ # refer to https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ # Binding to port 0 will cause the OS to find an available port for us
+ sock.bind(('', 0))
+ port = sock.getsockname()[1]
+ sock.close()
+ # NOTE: there is still a chance the port could be taken by other processes.
+ return port
+
+def _is_free_port(port):
+ ips = socket.gethostbyname_ex(socket.gethostname())[-1]
+ ips.append('localhost')
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ return all(s.connect_ex((ip, port)) != 0 for ip in ips)
+
+
+# def collect_env():
+# """Collect the information of the running environments."""
+# env_info = collect_base_env()
+# env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}'
+
+# return env_info
+
+def init_env(launcher, cfg):
+ """Initialize distributed training environment.
+ If argument ``cfg.dist_params.dist_url`` is specified as 'env://', then the master port will be system
+ environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
+ environment variable, then a default port ``29500`` will be used.
+ """
+ if launcher == 'slurm':
+ _init_dist_slurm(cfg)
+ elif launcher == 'ror':
+ _init_dist_ror(cfg)
+ elif launcher == 'None':
+ _init_none_dist(cfg)
+ else:
+ raise RuntimeError(f'{cfg.launcher} has not been supported!')
+
+def _init_none_dist(cfg):
+ cfg.dist_params.num_gpus_per_node = 1
+ cfg.dist_params.world_size = 1
+ cfg.dist_params.nnodes = 1
+ cfg.dist_params.node_rank = 0
+ cfg.dist_params.global_rank = 0
+ cfg.dist_params.local_rank = 0
+ os.environ["WORLD_SIZE"] = str(1)
+
+def _init_dist_ror(cfg):
+ from ac2.ror.comm import get_local_rank, get_world_rank, get_local_size, get_node_rank, get_world_size
+ cfg.dist_params.num_gpus_per_node = get_local_size()
+ cfg.dist_params.world_size = get_world_size()
+ cfg.dist_params.nnodes = (get_world_size()) // (get_local_size())
+ cfg.dist_params.node_rank = get_node_rank()
+ cfg.dist_params.global_rank = get_world_rank()
+ cfg.dist_params.local_rank = get_local_rank()
+ os.environ["WORLD_SIZE"] = str(get_world_size())
+
+
+def _init_dist_slurm(cfg):
+ if 'NNODES' not in os.environ:
+ os.environ['NNODES'] = str(cfg.dist_params.nnodes)
+ if 'NODE_RANK' not in os.environ:
+ os.environ['NODE_RANK'] = str(cfg.dist_params.node_rank)
+
+ #cfg.dist_params.
+ num_gpus = torch.cuda.device_count()
+ world_size = int(os.environ['NNODES']) * num_gpus
+ os.environ['WORLD_SIZE'] = str(world_size)
+
+ # config port
+ if 'MASTER_PORT' in os.environ:
+ master_port = str(os.environ['MASTER_PORT']) # use MASTER_PORT in the environment variable
+ else:
+ # if torch.distributed default port(29500) is available
+ # then use it, else find a free port
+ if _is_free_port(16500):
+ master_port = '16500'
+ else:
+ master_port = str(_find_free_port())
+ os.environ['MASTER_PORT'] = master_port
+
+ # config addr
+ if 'MASTER_ADDR' in os.environ:
+ master_addr = str(os.environ['MASTER_PORT']) # use MASTER_PORT in the environment variable
+ # elif cfg.dist_params.dist_url is not None:
+ # master_addr = ':'.join(str(cfg.dist_params.dist_url).split(':')[:2])
+ else:
+ master_addr = '127.0.0.1' #'tcp://127.0.0.1'
+ os.environ['MASTER_ADDR'] = master_addr
+
+ # set dist_url to 'env://'
+ cfg.dist_params.dist_url = 'env://' #f"{master_addr}:{master_port}"
+
+ cfg.dist_params.num_gpus_per_node = num_gpus
+ cfg.dist_params.world_size = world_size
+ cfg.dist_params.nnodes = int(os.environ['NNODES'])
+ cfg.dist_params.node_rank = int(os.environ['NODE_RANK'])
+
+ # if int(os.environ['NNODES']) > 1 and cfg.dist_params.dist_url.startswith("file://"):
+ # raise Warning("file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://")
+
+
+def get_func(func_name):
+ """
+ Helper to return a function object by name. func_name must identify
+ a function in this module or the path to a function relative to the base
+ module.
+ @ func_name: function name.
+ """
+ if func_name == '':
+ return None
+ try:
+ parts = func_name.split('.')
+ # Refers to a function in this module
+ if len(parts) == 1:
+ return globals()[parts[0]]
+ # Otherwise, assume we're referencing a module under modeling
+ module_name = '.'.join(parts[:-1])
+ module = importlib.import_module(module_name)
+ return getattr(module, parts[-1])
+ except:
+ raise RuntimeError(f'Failed to find function: {func_name}')
+
+class Timer(object):
+ """A simple timer."""
+
+ def __init__(self):
+ self.reset()
+
+ def tic(self):
+ # using time.time instead of time.clock because time time.clock
+ # does not normalize for multithreading
+ self.start_time = time.time()
+
+ def toc(self, average=True):
+ self.diff = time.time() - self.start_time
+ self.total_time += self.diff
+ self.calls += 1
+ self.average_time = self.total_time / self.calls
+ if average:
+ return self.average_time
+ else:
+ return self.diff
+
+ def reset(self):
+ self.total_time = 0.
+ self.calls = 0
+ self.start_time = 0.
+ self.diff = 0.
+ self.average_time = 0.
+
+class TrainingStats(object):
+ """Track vital training statistics."""
+ def __init__(self, log_period, tensorboard_logger=None):
+ self.log_period = log_period
+ self.tblogger = tensorboard_logger
+ self.tb_ignored_keys = ['iter', 'eta', 'epoch', 'time']
+ self.iter_timer = Timer()
+ # Window size for smoothing tracked values (with median filtering)
+ self.filter_size = log_period
+ def create_smoothed_value():
+ return AverageMeter()
+ self.smoothed_losses = defaultdict(create_smoothed_value)
+ #self.smoothed_metrics = defaultdict(create_smoothed_value)
+ #self.smoothed_total_loss = AverageMeter()
+
+
+ def IterTic(self):
+ self.iter_timer.tic()
+
+ def IterToc(self):
+ return self.iter_timer.toc(average=False)
+
+ def reset_iter_time(self):
+ self.iter_timer.reset()
+
+ def update_iter_stats(self, losses_dict):
+ """Update tracked iteration statistics."""
+ for k, v in losses_dict.items():
+ self.smoothed_losses[k].update(float(v), 1)
+
+ def log_iter_stats(self, cur_iter, optimizer, max_iters, val_err={}):
+ """Log the tracked statistics."""
+ if (cur_iter % self.log_period == 0):
+ stats = self.get_stats(cur_iter, optimizer, max_iters, val_err)
+ log_stats(stats)
+ if self.tblogger:
+ self.tb_log_stats(stats, cur_iter)
+ for k, v in self.smoothed_losses.items():
+ v.reset()
+
+ def tb_log_stats(self, stats, cur_iter):
+ """Log the tracked statistics to tensorboard"""
+ for k in stats:
+ # ignore some logs
+ if k not in self.tb_ignored_keys:
+ v = stats[k]
+ if isinstance(v, dict):
+ self.tb_log_stats(v, cur_iter)
+ else:
+ self.tblogger.add_scalar(k, v, cur_iter)
+
+
+ def get_stats(self, cur_iter, optimizer, max_iters, val_err = {}):
+ eta_seconds = self.iter_timer.average_time * (max_iters - cur_iter)
+
+ eta = str(datetime.timedelta(seconds=int(eta_seconds)))
+ stats = OrderedDict(
+ iter=cur_iter, # 1-indexed
+ time=self.iter_timer.average_time,
+ eta=eta,
+ )
+ optimizer_state_dict = optimizer.state_dict()
+ lr = {}
+ for i in range(len(optimizer_state_dict['param_groups'])):
+ lr_name = 'group%d_lr' % i
+ lr[lr_name] = optimizer_state_dict['param_groups'][i]['lr']
+
+ stats['lr'] = OrderedDict(lr)
+ for k, v in self.smoothed_losses.items():
+ stats[k] = v.avg
+
+ stats['val_err'] = OrderedDict(val_err)
+ stats['max_iters'] = max_iters
+ return stats
+
+
+def reduce_dict(input_dict, average=True):
+ """
+ Reduce the values in the dictionary from all processes so that process with rank
+ 0 has the reduced results.
+ Args:
+ @input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
+ @average (bool): whether to do average or sum
+ Returns:
+ a dict with the same keys as input_dict, after reduction.
+ """
+ world_size = get_world_size()
+ if world_size < 2:
+ return input_dict
+ with torch.no_grad():
+ names = []
+ values = []
+ # sort the keys so that they are consistent across processes
+ for k in sorted(input_dict.keys()):
+ names.append(k)
+ values.append(input_dict[k])
+ values = torch.stack(values, dim=0)
+ dist.reduce(values, dst=0)
+ if dist.get_rank() == 0 and average:
+ # only main process gets accumulated, so only divide by
+ # world_size in this case
+ values /= world_size
+ reduced_dict = {k: v for k, v in zip(names, values)}
+ return reduced_dict
+
+
+def log_stats(stats):
+ logger = logging.getLogger()
+ """Log training statistics to terminal"""
+ lines = "[Step %d/%d]\n" % (
+ stats['iter'], stats['max_iters'])
+
+ lines += "\t\tloss: %.3f, time: %.6f, eta: %s\n" % (
+ stats['total_loss'], stats['time'], stats['eta'])
+
+ # log loss
+ lines += "\t\t"
+ for k, v in stats.items():
+ if 'loss' in k.lower() and 'total_loss' not in k.lower():
+ lines += "%s: %.3f" % (k, v) + ", "
+ lines = lines[:-3]
+ lines += '\n'
+
+ # validate criteria
+ lines += "\t\tlast val err:" + ", ".join("%s: %.6f" % (k, v) for k, v in stats['val_err'].items()) + ", "
+ lines += '\n'
+
+ # lr in different groups
+ lines += "\t\t" + ", ".join("%s: %.8f" % (k, v) for k, v in stats['lr'].items())
+ lines += '\n'
+ logger.info(lines[:-1]) # remove last new linen_pxl
+
diff --git a/src/custom_controlnet_aux/metric3d/mono/utils/custom_data.py b/src/custom_controlnet_aux/metric3d/mono/utils/custom_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9fab47478bc471c51b5454cc15550079ebec21b
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/utils/custom_data.py
@@ -0,0 +1,34 @@
+import glob
+import os
+import json
+import cv2
+
+def load_from_annos(anno_path):
+ with open(anno_path, 'r') as f:
+ annos = json.load(f)['files']
+
+ datas = []
+ for i, anno in enumerate(annos):
+ rgb = anno['rgb']
+ depth = anno['depth'] if 'depth' in anno else None
+ depth_scale = anno['depth_scale'] if 'depth_scale' in anno else 1.0
+ intrinsic = anno['cam_in'] if 'cam_in' in anno else None
+ normal = anno['normal'] if 'normal' in anno else None
+
+ data_i = {
+ 'rgb': rgb,
+ 'depth': depth,
+ 'depth_scale': depth_scale,
+ 'intrinsic': intrinsic,
+ 'filename': os.path.basename(rgb),
+ 'folder': rgb.split('/')[-3],
+ 'normal': normal
+ }
+ datas.append(data_i)
+ return datas
+
+def load_data(path: str):
+ rgbs = glob.glob(path + '/*.jpg') + glob.glob(path + '/*.png')
+ #intrinsic = [835.8179931640625, 835.8179931640625, 961.5419921875, 566.8090209960938] #[721.53769, 721.53769, 609.5593, 172.854]
+ data = [{'rgb': i, 'depth': None, 'intrinsic': None, 'filename': os.path.basename(i), 'folder': i.split('/')[-3]} for i in rgbs]
+ return data
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/metric3d/mono/utils/do_test.py b/src/custom_controlnet_aux/metric3d/mono/utils/do_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..cccd2707aa09a552848e8f53da437e421f2bf544
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/utils/do_test.py
@@ -0,0 +1,380 @@
+import torch
+import torch.nn.functional as F
+import logging
+import os
+import os.path as osp
+from .avg_meter import MetricAverageMeter
+from .visualization import save_val_imgs, create_html, save_raw_imgs, save_normal_val_imgs
+import cv2
+from tqdm import tqdm
+import numpy as np
+from PIL import Image
+import matplotlib.pyplot as plt
+
+def to_cuda(data: dict):
+ for k, v in data.items():
+ if isinstance(v, torch.Tensor):
+ data[k] = v.cuda(non_blocking=True)
+ if isinstance(v, list) and len(v)>=1 and isinstance(v[0], torch.Tensor):
+ for i, l_i in enumerate(v):
+ data[k][i] = l_i.cuda(non_blocking=True)
+ return data
+
+def align_scale(pred: torch.tensor, target: torch.tensor):
+ mask = target > 0
+ if torch.sum(mask) > 10:
+ scale = torch.median(target[mask]) / (torch.median(pred[mask]) + 1e-8)
+ else:
+ scale = 1
+ pred_scaled = pred * scale
+ return pred_scaled, scale
+
+def align_scale_shift(pred: torch.tensor, target: torch.tensor):
+ mask = target > 0
+ target_mask = target[mask].cpu().numpy()
+ pred_mask = pred[mask].cpu().numpy()
+ if torch.sum(mask) > 10:
+ scale, shift = np.polyfit(pred_mask, target_mask, deg=1)
+ if scale < 0:
+ scale = torch.median(target[mask]) / (torch.median(pred[mask]) + 1e-8)
+ shift = 0
+ else:
+ scale = 1
+ shift = 0
+ pred = pred * scale + shift
+ return pred, scale
+
+def align_scale_shift_numpy(pred: np.array, target: np.array):
+ mask = target > 0
+ target_mask = target[mask]
+ pred_mask = pred[mask]
+ if np.sum(mask) > 10:
+ scale, shift = np.polyfit(pred_mask, target_mask, deg=1)
+ if scale < 0:
+ scale = np.median(target[mask]) / (np.median(pred[mask]) + 1e-8)
+ shift = 0
+ else:
+ scale = 1
+ shift = 0
+ pred = pred * scale + shift
+ return pred, scale
+
+
+def build_camera_model(H : int, W : int, intrinsics : list) -> np.array:
+ """
+ Encode the camera intrinsic parameters (focal length and principle point) to a 4-channel map.
+ """
+ fx, fy, u0, v0 = intrinsics
+ f = (fx + fy) / 2.0
+ # principle point location
+ x_row = np.arange(0, W).astype(np.float32)
+ x_row_center_norm = (x_row - u0) / W
+ x_center = np.tile(x_row_center_norm, (H, 1)) # [H, W]
+
+ y_col = np.arange(0, H).astype(np.float32)
+ y_col_center_norm = (y_col - v0) / H
+ y_center = np.tile(y_col_center_norm, (W, 1)).T # [H, W]
+
+ # FoV
+ fov_x = np.arctan(x_center / (f / W))
+ fov_y = np.arctan(y_center / (f / H))
+
+ cam_model = np.stack([x_center, y_center, fov_x, fov_y], axis=2)
+ return cam_model
+
+def resize_for_input(image, output_shape, intrinsic, canonical_shape, to_canonical_ratio):
+ """
+ Resize the input.
+ Resizing consists of two processed, i.e. 1) to the canonical space (adjust the camera model); 2) resize the image while the camera model holds. Thus the
+ label will be scaled with the resize factor.
+ """
+ padding = [123.675, 116.28, 103.53]
+ h, w, _ = image.shape
+ resize_ratio_h = output_shape[0] / canonical_shape[0]
+ resize_ratio_w = output_shape[1] / canonical_shape[1]
+ to_scale_ratio = min(resize_ratio_h, resize_ratio_w)
+
+ resize_ratio = to_canonical_ratio * to_scale_ratio
+
+ reshape_h = int(resize_ratio * h)
+ reshape_w = int(resize_ratio * w)
+
+ pad_h = max(output_shape[0] - reshape_h, 0)
+ pad_w = max(output_shape[1] - reshape_w, 0)
+ pad_h_half = int(pad_h / 2)
+ pad_w_half = int(pad_w / 2)
+
+ # resize
+ image = cv2.resize(image, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR)
+ # padding
+ image = cv2.copyMakeBorder(
+ image,
+ pad_h_half,
+ pad_h - pad_h_half,
+ pad_w_half,
+ pad_w - pad_w_half,
+ cv2.BORDER_CONSTANT,
+ value=padding)
+
+ # Resize, adjust principle point
+ intrinsic[2] = intrinsic[2] * to_scale_ratio
+ intrinsic[3] = intrinsic[3] * to_scale_ratio
+
+ cam_model = build_camera_model(reshape_h, reshape_w, intrinsic)
+ cam_model = cv2.copyMakeBorder(
+ cam_model,
+ pad_h_half,
+ pad_h - pad_h_half,
+ pad_w_half,
+ pad_w - pad_w_half,
+ cv2.BORDER_CONSTANT,
+ value=-1)
+
+ pad=[pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half]
+ label_scale_factor=1/to_scale_ratio
+ return image, cam_model, pad, label_scale_factor
+
+
+def get_prediction(
+ model: torch.nn.Module,
+ input: torch.tensor,
+ cam_model: torch.tensor,
+ pad_info: torch.tensor,
+ scale_info: torch.tensor,
+ gt_depth: torch.tensor,
+ normalize_scale: float,
+ ori_shape: list=[],
+):
+
+ data = dict(
+ input=input,
+ cam_model=cam_model,
+ )
+ pred_depth, confidence, output_dict = model.inference(data)
+
+ return pred_depth, confidence, output_dict
+
+def transform_test_data_scalecano(rgb, intrinsic, data_basic, device="cuda"):
+ """
+ Pre-process the input for forwarding. Employ `label scale canonical transformation.'
+ Args:
+ rgb: input rgb image. [H, W, 3]
+ intrinsic: camera intrinsic parameter, [fx, fy, u0, v0]
+ data_basic: predefined canonical space in configs.
+ """
+ canonical_space = data_basic['canonical_space']
+ forward_size = data_basic.crop_size
+ mean = torch.tensor([123.675, 116.28, 103.53]).float()[:, None, None]
+ std = torch.tensor([58.395, 57.12, 57.375]).float()[:, None, None]
+
+ # BGR to RGB
+ #rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
+
+ ori_h, ori_w, _ = rgb.shape
+ ori_focal = (intrinsic[0] + intrinsic[1]) / 2
+ canonical_focal = canonical_space['focal_length']
+
+ cano_label_scale_ratio = canonical_focal / ori_focal
+
+ canonical_intrinsic = [
+ intrinsic[0] * cano_label_scale_ratio,
+ intrinsic[1] * cano_label_scale_ratio,
+ intrinsic[2],
+ intrinsic[3],
+ ]
+
+ # resize
+ rgb, cam_model, pad, resize_label_scale_ratio = resize_for_input(rgb, forward_size, canonical_intrinsic, [ori_h, ori_w], 1.0)
+
+ # label scale factor
+ label_scale_factor = cano_label_scale_ratio * resize_label_scale_ratio
+
+ rgb = torch.from_numpy(rgb.transpose((2, 0, 1))).float()
+ rgb = torch.div((rgb - mean), std)
+ rgb = rgb.to(device)
+
+ cam_model = torch.from_numpy(cam_model.transpose((2, 0, 1))).float()
+ cam_model = cam_model[None, :, :, :].to(device)
+ cam_model_stacks = [
+ torch.nn.functional.interpolate(cam_model, size=(cam_model.shape[2]//i, cam_model.shape[3]//i), mode='bilinear', align_corners=False)
+ for i in [2, 4, 8, 16, 32]
+ ]
+ return rgb, cam_model_stacks, pad, label_scale_factor
+
+def do_scalecano_test_with_custom_data(
+ model: torch.nn.Module,
+ cfg: dict,
+ test_data: list,
+ logger: logging.RootLogger,
+ is_distributed: bool = True,
+ local_rank: int = 0,
+ bs: int = 2, # Batch size parameter
+):
+
+ show_dir = cfg.show_dir
+ save_interval = 1
+ save_imgs_dir = show_dir + '/vis'
+ os.makedirs(save_imgs_dir, exist_ok=True)
+ save_pcd_dir = show_dir + '/pcd'
+ os.makedirs(save_pcd_dir, exist_ok=True)
+
+ normalize_scale = cfg.data_basic.depth_range[1]
+ dam = MetricAverageMeter(['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3'])
+ dam_median = MetricAverageMeter(['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3'])
+ dam_global = MetricAverageMeter(['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3'])
+
+ # Process data in batches
+ for i in tqdm(range(0, len(test_data), bs)):
+ batch_data = test_data[i:i + bs] # Extract batch
+ rgb_inputs, pads, label_scale_factors, gt_depths, rgb_origins = [], [], [], [], []
+
+ for an in batch_data:
+ print(an['rgb'])
+ rgb_origin = cv2.imread(an['rgb'])[:, :, ::-1].copy()
+ rgb_origins.append(rgb_origin)
+ gt_depth = None
+ if an['depth'] is not None:
+ gt_depth = cv2.imread(an['depth'], -1)
+ gt_depth_scale = an['depth_scale']
+ gt_depth = gt_depth / gt_depth_scale
+ gt_depths.append(gt_depth)
+
+ intrinsic = an['intrinsic']
+ if intrinsic is None:
+ intrinsic = [1000.0, 1000.0, rgb_origin.shape[1]/2, rgb_origin.shape[0]/2]
+
+ rgb_input, _, pad, label_scale_factor = transform_test_data_scalecano(rgb_origin, intrinsic, cfg.data_basic)
+ rgb_inputs.append(rgb_input)
+ pads.append(pad)
+ label_scale_factors.append(label_scale_factor)
+
+ # Process the batch
+ pred_depths, outputs = get_prediction(
+ model=model,
+ input=torch.stack(rgb_inputs), # Stack inputs for batch processing
+ cam_model=None,
+ pad_info=pads,
+ scale_info=None,
+ gt_depth=None,
+ normalize_scale=None,
+ )
+
+ for j, gt_depth in enumerate(gt_depths):
+ normal_out = None
+ if 'normal_out_list' in outputs.keys():
+ normal_out = outputs['normal_out_list'][0][j, :]
+
+ postprocess_per_image(
+ i*bs+j,
+ pred_depths[j, :],
+ gt_depth,
+ intrinsic,
+ rgb_origins[j],
+ normal_out,
+ pads[j],
+ batch_data[j],
+ dam,
+ dam_median,
+ dam_global,
+ is_distributed,
+ save_imgs_dir,
+ save_pcd_dir,
+ normalize_scale,
+ label_scale_factors[j],
+ )
+
+ #if gt_depth_flag:
+ if False:
+ eval_error = dam.get_metrics()
+ print('w/o match :', eval_error)
+
+ eval_error_median = dam_median.get_metrics()
+ print('median match :', eval_error_median)
+
+ eval_error_global = dam_global.get_metrics()
+ print('global match :', eval_error_global)
+ else:
+ print('missing gt_depth, only save visualizations...')
+
+
+def postprocess_per_image(i, pred_depth, gt_depth, intrinsic, rgb_origin, normal_out, pad, an, dam, dam_median, dam_global, is_distributed, save_imgs_dir, save_pcd_dir, normalize_scale, scale_info):
+
+ pred_depth = pred_depth.squeeze()
+ pred_depth = pred_depth[pad[0] : pred_depth.shape[0] - pad[1], pad[2] : pred_depth.shape[1] - pad[3]]
+ pred_depth = torch.nn.functional.interpolate(pred_depth[None, None, :, :], [rgb_origin.shape[0], rgb_origin.shape[1]], mode='bilinear').squeeze() # to original size
+ pred_depth = pred_depth * normalize_scale / scale_info
+
+ pred_depth = (pred_depth > 0) * (pred_depth < 300) * pred_depth
+ if gt_depth is not None:
+
+ pred_depth = torch.nn.functional.interpolate(pred_depth[None, None, :, :], (gt_depth.shape[0], gt_depth.shape[1]), mode='bilinear').squeeze() # to original size
+
+ gt_depth = torch.from_numpy(gt_depth).cuda()
+
+ pred_depth_median = pred_depth * gt_depth[gt_depth != 0].median() / pred_depth[gt_depth != 0].median()
+ pred_global, _ = align_scale_shift(pred_depth, gt_depth)
+
+ mask = (gt_depth > 1e-8)
+ dam.update_metrics_gpu(pred_depth, gt_depth, mask, is_distributed)
+ dam_median.update_metrics_gpu(pred_depth_median, gt_depth, mask, is_distributed)
+ dam_global.update_metrics_gpu(pred_global, gt_depth, mask, is_distributed)
+ print(gt_depth[gt_depth != 0].median() / pred_depth[gt_depth != 0].median(), )
+
+ os.makedirs(osp.join(save_imgs_dir, an['folder']), exist_ok=True)
+ rgb_torch = torch.from_numpy(rgb_origin).to(pred_depth.device).permute(2, 0, 1)
+ mean = torch.tensor([123.675, 116.28, 103.53]).float()[:, None, None].to(rgb_torch.device)
+ std = torch.tensor([58.395, 57.12, 57.375]).float()[:, None, None].to(rgb_torch.device)
+ rgb_torch = torch.div((rgb_torch - mean), std)
+
+ save_val_imgs(
+ i,
+ pred_depth,
+ gt_depth if gt_depth is not None else torch.ones_like(pred_depth, device=pred_depth.device),
+ rgb_torch,
+ osp.join(an['folder'], an['filename']),
+ save_imgs_dir,
+ )
+ #save_raw_imgs(pred_depth.detach().cpu().numpy(), rgb_torch, osp.join(an['folder'], an['filename']), save_imgs_dir, 1000.0)
+
+ # pcd
+ pred_depth = pred_depth.detach().cpu().numpy()
+ #pcd = reconstruct_pcd(pred_depth, intrinsic[0], intrinsic[1], intrinsic[2], intrinsic[3])
+ #os.makedirs(osp.join(save_pcd_dir, an['folder']), exist_ok=True)
+ #save_point_cloud(pcd.reshape((-1, 3)), rgb_origin.reshape(-1, 3), osp.join(save_pcd_dir, an['folder'], an['filename'][:-4]+'.ply'))
+
+ if an['intrinsic'] == None:
+ #for r in [0.9, 1.0, 1.1]:
+ for r in [1.0]:
+ #for f in [600, 800, 1000, 1250, 1500]:
+ for f in [1000]:
+ pcd = reconstruct_pcd(pred_depth, f * r, f * (2-r), intrinsic[2], intrinsic[3])
+ fstr = '_fx_' + str(int(f * r)) + '_fy_' + str(int(f * (2-r)))
+ os.makedirs(osp.join(save_pcd_dir, an['folder']), exist_ok=True)
+ save_point_cloud(pcd.reshape((-1, 3)), rgb_origin.reshape(-1, 3), osp.join(save_pcd_dir, an['folder'], an['filename'][:-4] + fstr +'.ply'))
+
+ if normal_out is not None:
+ pred_normal = normal_out[:3, :, :] # (3, H, W)
+ H, W = pred_normal.shape[1:]
+ pred_normal = pred_normal[ :, pad[0]:H-pad[1], pad[2]:W-pad[3]]
+
+ gt_normal = None
+ #if gt_normal_flag:
+ if False:
+ pred_normal = torch.nn.functional.interpolate(pred_normal, size=gt_normal.shape[2:], mode='bilinear', align_corners=True)
+ gt_normal = cv2.imread(norm_path)
+ gt_normal = cv2.cvtColor(gt_normal, cv2.COLOR_BGR2RGB)
+ gt_normal = np.array(gt_normal).astype(np.uint8)
+ gt_normal = ((gt_normal.astype(np.float32) / 255.0) * 2.0) - 1.0
+ norm_valid_mask = (np.linalg.norm(gt_normal, axis=2, keepdims=True) > 0.5)
+ gt_normal = gt_normal * norm_valid_mask
+ gt_normal_mask = ~torch.all(gt_normal == 0, dim=1, keepdim=True)
+ dam.update_normal_metrics_gpu(pred_normal, gt_normal, gt_normal_mask, cfg.distributed)# save valiad normal
+
+ save_normal_val_imgs(iter,
+ pred_normal,
+ gt_normal if gt_normal is not None else torch.ones_like(pred_normal, device=pred_normal.device),
+ rgb_torch, # data['input'],
+ osp.join(an['folder'], 'normal_'+an['filename']),
+ save_imgs_dir,
+ )
+
diff --git a/src/custom_controlnet_aux/metric3d/mono/utils/logger.py b/src/custom_controlnet_aux/metric3d/mono/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca48c613b2fdc5352b13ccb7d0bfdc1df5e3b531
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/utils/logger.py
@@ -0,0 +1,102 @@
+import atexit
+import logging
+import os
+import sys
+import time
+import torch
+from termcolor import colored
+
+__all__ = ["setup_logger", ]
+
+class _ColorfulFormatter(logging.Formatter):
+ def __init__(self, *args, **kwargs):
+ self._root_name = kwargs.pop("root_name") + "."
+ self._abbrev_name = kwargs.pop("abbrev_name", "")
+ if len(self._abbrev_name):
+ self._abbrev_name = self._abbrev_name + "."
+ super(_ColorfulFormatter, self).__init__(*args, **kwargs)
+
+ def formatMessage(self, record):
+ record.name = record.name.replace(self._root_name, self._abbrev_name)
+ log = super(_ColorfulFormatter, self).formatMessage(record)
+ if record.levelno == logging.WARNING:
+ prefix = colored("WARNING", "red", attrs=["blink"])
+ elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
+ prefix = colored("ERROR", "red", attrs=["blink", "underline"])
+ else:
+ return log
+ return prefix + " " + log
+
+def setup_logger(
+ output=None, distributed_rank=0, *, name='metricdepth', color=True, abbrev_name=None
+):
+ """
+ Initialize the detectron2 logger and set its verbosity level to "DEBUG".
+ Args:
+ output (str): a file name or a directory to save log. If None, will not save log file.
+ If ends with ".txt" or ".log", assumed to be a file name.
+ Otherwise, logs will be saved to `output/log.txt`.
+ abbrev_name (str): an abbreviation of the module, to avoid log names in logs.
+ Set to "" not log the root module in logs.
+ By default, will abbreviate "detectron2" to "d2" and leave other
+ modules unchanged.
+ Returns:
+ logging.Logger: a logger
+ """
+ logger = logging.getLogger()
+ logger.setLevel(logging.INFO) # NOTE: if more detailed, change it to logging.DEBUG
+ logger.propagate = False
+
+ if abbrev_name is None:
+ abbrev_name = "d2"
+
+ plain_formatter = logging.Formatter(
+ "[%(asctime)s] %(name)s %(levelname)s %(message)s ", datefmt="%m/%d %H:%M:%S"
+ )
+ # stdout logging: master only
+ if distributed_rank == 0:
+ ch = logging.StreamHandler(stream=sys.stdout)
+ ch.setLevel(logging.INFO) # NOTE: if more detailed, change it to logging.DEBUG
+ if color:
+ formatter = _ColorfulFormatter(
+ colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
+ datefmt="%m/%d %H:%M:%S",
+ root_name=name,
+ abbrev_name=str(abbrev_name),
+ )
+ else:
+ formatter = plain_formatter
+ ch.setFormatter(formatter)
+ logger.addHandler(ch)
+
+ # file logging: all workers
+ if output is not None:
+ if output.endswith(".txt") or output.endswith(".log"):
+ filename = output
+ else:
+ filename = os.path.join(output, "log.txt")
+ if distributed_rank > 0:
+ filename = filename + ".rank{}".format(distributed_rank)
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+
+ fh = logging.StreamHandler(_cached_log_stream(filename))
+ fh.setLevel(logging.INFO) # NOTE: if more detailed, change it to logging.DEBUG
+ fh.setFormatter(plain_formatter)
+ logger.addHandler(fh)
+
+
+ return logger
+
+from iopath.common.file_io import PathManager as PathManagerBase
+
+
+PathManager = PathManagerBase()
+
+# cache the opened file object, so that different calls to 'setup_logger
+# with the same file name can safely write to the same file.
+def _cached_log_stream(filename):
+ # use 1K buffer if writting to cloud storage
+ io = PathManager.open(filename, "a", buffering=1024 if "://" in filename else -1)
+ atexit.register(io.close)
+ return io
+
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/metric3d/mono/utils/mldb.py b/src/custom_controlnet_aux/metric3d/mono/utils/mldb.py
new file mode 100644
index 0000000000000000000000000000000000000000..d74ac53fd0302e2e954105bade52e6de4c18e2f6
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/utils/mldb.py
@@ -0,0 +1,34 @@
+from types import ModuleType
+import data_info
+
+def load_data_info(module_name, data_info={}, mldb_type='mldb_info', module=None):
+ if module is None:
+ module = globals().get(module_name, None)
+ if module:
+ for key, value in module.__dict__.items():
+ if not (key.startswith('__')) and not (key.startswith('_')):
+ if key == 'mldb_info':
+ data_info.update(value)
+ elif isinstance(value, ModuleType):
+ load_data_info(module_name + '.' + key, data_info, module=value)
+ else:
+ raise RuntimeError(f'Try to access "mldb_info", but cannot find {module_name} module.')
+
+def reset_ckpt_path(cfg, data_info):
+ if isinstance(cfg, dict):
+ for key in cfg.keys():
+ if key == 'backbone':
+ new_ckpt_path = data_info['checkpoint']['mldb_root'] + '/' + data_info['checkpoint'][cfg.backbone.type]
+ cfg.backbone.update(checkpoint=new_ckpt_path)
+ continue
+ elif isinstance(cfg.get(key), dict):
+ reset_ckpt_path(cfg.get(key), data_info)
+ else:
+ continue
+ else:
+ return
+
+if __name__ == '__main__':
+ mldb_info_tmp = {}
+ load_data_info('mldb_data_info', mldb_info_tmp)
+ print('results', mldb_info_tmp.keys())
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/metric3d/mono/utils/pcd_filter.py b/src/custom_controlnet_aux/metric3d/mono/utils/pcd_filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d26314d806ea961f6bf09d1fb195bf5e364f181
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/utils/pcd_filter.py
@@ -0,0 +1,24 @@
+import open3d as o3d
+import numpy as np
+
+def downsample_and_filter(pcd_file):
+ pcd = o3d.io.read_point_cloud(pcd_file, max_bound_div = 750, neighbor_num = 8)
+ point_num = len(pcd.points)
+ if (point_num > 10000000):
+ voxel_down_pcd = o3d.geometry.PointCloud.uniform_down_sample(pcd, int(point_num / 10000000)+1)
+ else:
+ voxel_down_pcd = pcd
+ max_bound = voxel_down_pcd.get_max_bound()
+ ball_radius = np.linalg.norm(max_bound) / max_bound_div
+ pcd_filter, _ = voxel_down_pcd.remove_radius_outlier(neighbor_num, ball_radius)
+ print('filtered size', len(pcd_filter.points), 'pre size:', len(pcd.points))
+ o3d.io.write_point_cloud(pcd_file[:-4] + '_filtered.ply', pcd_filter)
+
+
+if __name__ == "__main__":
+ import os
+ dir_path = './data/demo_pcd'
+ for pcd_file in os.listdir(dir_path):
+ #if 'jonathan' in pcd_file: set max_bound_div to 300 and neighbot_num to 8
+ downsample_and_filter(os.path.join(dir_path, pcd_file))
+
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/metric3d/mono/utils/running.py b/src/custom_controlnet_aux/metric3d/mono/utils/running.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dce71cf64155bbb6fa269efabfc841ab75f68b4
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/utils/running.py
@@ -0,0 +1,77 @@
+import os
+import torch
+import torch.nn as nn
+from .comm import main_process
+import copy
+import inspect
+import logging
+import glob
+
+
+def load_ckpt(load_path, model, optimizer=None, scheduler=None, strict_match=True, loss_scaler=None):
+ """
+ Load the check point for resuming training or finetuning.
+ """
+ logger = logging.getLogger()
+ if os.path.isfile(load_path):
+ if main_process():
+ logger.info(f"Loading weight '{load_path}'")
+ checkpoint = torch.load(load_path, map_location="cpu", weights_only=True)
+ ckpt_state_dict = checkpoint['model_state_dict']
+ model.load_state_dict(ckpt_state_dict, strict=strict_match)
+
+ if optimizer is not None:
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ if scheduler is not None:
+ scheduler.load_state_dict(checkpoint['scheduler'])
+ if loss_scaler is not None and 'scaler' in checkpoint:
+ scheduler.load_state_dict(checkpoint['scaler'])
+ del ckpt_state_dict
+ del checkpoint
+ if main_process():
+ logger.info(f"Successfully loaded weight: '{load_path}'")
+ if scheduler is not None and optimizer is not None:
+ logger.info(f"Resume training from: '{load_path}'")
+ else:
+ if main_process():
+ raise RuntimeError(f"No weight found at '{load_path}'")
+ return model, optimizer, scheduler, loss_scaler
+
+
+def save_ckpt(cfg, model, optimizer, scheduler, curr_iter=0, curr_epoch=None, loss_scaler=None):
+ """
+ Save the model, optimizer, lr scheduler.
+ """
+ logger = logging.getLogger()
+
+ if 'IterBasedRunner' in cfg.runner.type:
+ max_iters = cfg.runner.max_iters
+ elif 'EpochBasedRunner' in cfg.runner.type:
+ max_iters = cfg.runner.max_epochs
+ else:
+ raise TypeError(f'{cfg.runner.type} is not supported')
+
+ ckpt = dict(
+ model_state_dict=model.module.state_dict(),
+ optimizer=optimizer.state_dict(),
+ max_iter=cfg.runner.max_iters if 'max_iters' in cfg.runner \
+ else cfg.runner.max_epochs,
+ scheduler=scheduler.state_dict(),
+ )
+
+ if loss_scaler is not None:
+ ckpt.update(dict(scaler=loss_scaler.state_dict()))
+
+ ckpt_dir = os.path.join(cfg.work_dir, 'ckpt')
+ os.makedirs(ckpt_dir, exist_ok=True)
+
+ save_name = os.path.join(ckpt_dir, 'step%08d.pth' %curr_iter)
+ saved_ckpts = glob.glob(ckpt_dir + '/step*.pth')
+ torch.save(ckpt, save_name)
+
+ # keep the last 8 ckpts
+ if len(saved_ckpts) > 20:
+ saved_ckpts.sort()
+ os.remove(saved_ckpts.pop(0))
+
+ logger.info(f'Save model: {save_name}')
diff --git a/src/custom_controlnet_aux/metric3d/mono/utils/transform.py b/src/custom_controlnet_aux/metric3d/mono/utils/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..2af94efe754d6f72325db6fdc170f30fbfb8c2fe
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/utils/transform.py
@@ -0,0 +1,408 @@
+import collections
+import cv2
+import math
+import numpy as np
+import numbers
+import random
+import torch
+
+import matplotlib
+import matplotlib.cm
+
+
+"""
+Provides a set of Pytorch transforms that use OpenCV instead of PIL (Pytorch default)
+for image manipulation.
+"""
+
+class Compose(object):
+ # Composes transforms: transforms.Compose([transforms.RandScale([0.5, 2.0]), transforms.ToTensor()])
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None):
+ for t in self.transforms:
+ images, labels, intrinsics, cam_models, other_labels, transform_paras = t(images, labels, intrinsics, cam_models, other_labels, transform_paras)
+ return images, labels, intrinsics, cam_models, other_labels, transform_paras
+
+
+class ToTensor(object):
+ # Converts numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W).
+ def __init__(self, **kwargs):
+ return
+ def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None):
+ if not isinstance(images, list) or not isinstance(labels, list) or not isinstance(intrinsics, list):
+ raise (RuntimeError("transform.ToTensor() only handle inputs/labels/intrinsics lists."))
+ if len(images) != len(intrinsics):
+ raise (RuntimeError("Numbers of images and intrinsics are not matched."))
+ if not isinstance(images[0], np.ndarray) or not isinstance(labels[0], np.ndarray):
+ raise (RuntimeError("transform.ToTensor() only handle np.ndarray for the input and label."
+ "[eg: data readed by cv2.imread()].\n"))
+ if not isinstance(intrinsics[0], list):
+ raise (RuntimeError("transform.ToTensor() only handle list for the camera intrinsics"))
+
+ if len(images[0].shape) > 3 or len(images[0].shape) < 2:
+ raise (RuntimeError("transform.ToTensor() only handle image(np.ndarray) with 3 dims or 2 dims.\n"))
+ if len(labels[0].shape) > 3 or len(labels[0].shape) < 2:
+ raise (RuntimeError("transform.ToTensor() only handle label(np.ndarray) with 3 dims or 2 dims.\n"))
+
+ if len(intrinsics[0]) >4 or len(intrinsics[0]) < 3:
+ raise (RuntimeError("transform.ToTensor() only handle intrinsic(list) with 3 sizes or 4 sizes.\n"))
+
+ for i, img in enumerate(images):
+ if len(img.shape) == 2:
+ img = np.expand_dims(img, axis=2)
+ images[i] = torch.from_numpy(img.transpose((2, 0, 1))).float()
+ for i, lab in enumerate(labels):
+ if len(lab.shape) == 2:
+ lab = np.expand_dims(lab, axis=0)
+ labels[i] = torch.from_numpy(lab).float()
+ for i, intrinsic in enumerate(intrinsics):
+ if len(intrinsic) == 3:
+ intrinsic = [intrinsic[0],] + intrinsic
+ intrinsics[i] = torch.tensor(intrinsic, dtype=torch.float)
+ if cam_models is not None:
+ for i, cam_model in enumerate(cam_models):
+ cam_models[i] = torch.from_numpy(cam_model.transpose((2, 0, 1))).float() if cam_model is not None else None
+ if other_labels is not None:
+ for i, lab in enumerate(other_labels):
+ if len(lab.shape) == 2:
+ lab = np.expand_dims(lab, axis=0)
+ other_labels[i] = torch.from_numpy(lab).float()
+ return images, labels, intrinsics, cam_models, other_labels, transform_paras
+
+
+class Normalize(object):
+ # Normalize tensor with mean and standard deviation along channel: channel = (channel - mean) / std
+ def __init__(self, mean, std=None, **kwargs):
+ if std is None:
+ assert len(mean) > 0
+ else:
+ assert len(mean) == len(std)
+ self.mean = torch.tensor(mean).float()[:, None, None]
+ self.std = torch.tensor(std).float()[:, None, None] if std is not None \
+ else torch.tensor([1.0, 1.0, 1.0]).float()[:, None, None]
+
+ def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None):
+ # if self.std is None:
+ # # for t, m in zip(image, self.mean):
+ # # t.sub(m)
+ # image = image - self.mean
+ # if ref_images is not None:
+ # for i, ref_i in enumerate(ref_images):
+ # ref_images[i] = ref_i - self.mean
+ # else:
+ # # for t, m, s in zip(image, self.mean, self.std):
+ # # t.sub(m).div(s)
+ # image = (image - self.mean) / self.std
+ # if ref_images is not None:
+ # for i, ref_i in enumerate(ref_images):
+ # ref_images[i] = (ref_i - self.mean) / self.std
+ for i, img in enumerate(images):
+ img = torch.div((img - self.mean), self.std)
+ images[i] = img
+ return images, labels, intrinsics, cam_models, other_labels, transform_paras
+
+
+class LableScaleCanonical(object):
+ """
+ To solve the ambiguity observation for the mono branch, i.e. different focal length (object size) with the same depth, cameras are
+ mapped to a canonical space. To mimic this, we set the focal length to a canonical one and scale the depth value. NOTE: resize the image based on the ratio can also solve
+ Args:
+ images: list of RGB images.
+ labels: list of depth/disparity labels.
+ other labels: other labels, such as instance segmentations, semantic segmentations...
+ """
+ def __init__(self, **kwargs):
+ self.canonical_focal = kwargs['focal_length']
+
+ def _get_scale_ratio(self, intrinsic):
+ target_focal_x = intrinsic[0]
+ label_scale_ratio = self.canonical_focal / target_focal_x
+ pose_scale_ratio = 1.0
+ return label_scale_ratio, pose_scale_ratio
+
+ def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None):
+ assert len(images[0].shape) == 3 and len(labels[0].shape) == 2
+ assert labels[0].dtype == np.float32
+
+ label_scale_ratio = None
+ pose_scale_ratio = None
+
+ for i in range(len(intrinsics)):
+ img_i = images[i]
+ label_i = labels[i] if i < len(labels) else None
+ intrinsic_i = intrinsics[i].copy()
+ cam_model_i = cam_models[i] if cam_models is not None and i < len(cam_models) else None
+
+ label_scale_ratio, pose_scale_ratio = self._get_scale_ratio(intrinsic_i)
+
+ # adjust the focal length, map the current camera to the canonical space
+ intrinsics[i] = [intrinsic_i[0] * label_scale_ratio, intrinsic_i[1] * label_scale_ratio, intrinsic_i[2], intrinsic_i[3]]
+
+ # scale the label to the canonical space
+ if label_i is not None:
+ labels[i] = label_i * label_scale_ratio
+
+ if cam_model_i is not None:
+ # As the focal length is adjusted (canonical focal length), the camera model should be re-built
+ ori_h, ori_w, _ = img_i.shape
+ cam_models[i] = build_camera_model(ori_h, ori_w, intrinsics[i])
+
+
+ if transform_paras is not None:
+ transform_paras.update(label_scale_factor=label_scale_ratio, focal_scale_factor=label_scale_ratio)
+
+ return images, labels, intrinsics, cam_models, other_labels, transform_paras
+
+
+class ResizeKeepRatio(object):
+ """
+ Resize and pad to a given size. Hold the aspect ratio.
+ This resizing assumes that the camera model remains unchanged.
+ Args:
+ resize_size: predefined output size.
+ """
+ def __init__(self, resize_size, padding=None, ignore_label=-1, **kwargs):
+ if isinstance(resize_size, int):
+ self.resize_h = resize_size
+ self.resize_w = resize_size
+ elif isinstance(resize_size, collections.Iterable) and len(resize_size) == 2 \
+ and isinstance(resize_size[0], int) and isinstance(resize_size[1], int) \
+ and resize_size[0] > 0 and resize_size[1] > 0:
+ self.resize_h = resize_size[0]
+ self.resize_w = resize_size[1]
+ else:
+ raise (RuntimeError("crop size error.\n"))
+ if padding is None:
+ self.padding = padding
+ elif isinstance(padding, list):
+ if all(isinstance(i, numbers.Number) for i in padding):
+ self.padding = padding
+ else:
+ raise (RuntimeError("padding in Crop() should be a number list\n"))
+ if len(padding) != 3:
+ raise (RuntimeError("padding channel is not equal with 3\n"))
+ else:
+ raise (RuntimeError("padding in Crop() should be a number list\n"))
+ if isinstance(ignore_label, int):
+ self.ignore_label = ignore_label
+ else:
+ raise (RuntimeError("ignore_label should be an integer number\n"))
+ # self.crop_size = kwargs['crop_size']
+ self.canonical_focal = kwargs['focal_length']
+
+ def main_data_transform(self, image, label, intrinsic, cam_model, resize_ratio, padding, to_scale_ratio):
+ """
+ Resize data first and then do the padding.
+ 'label' will be scaled.
+ """
+ h, w, _ = image.shape
+ reshape_h = int(resize_ratio * h)
+ reshape_w = int(resize_ratio * w)
+
+ pad_h, pad_w, pad_h_half, pad_w_half = padding
+
+ # resize
+ image = cv2.resize(image, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR)
+ # padding
+ image = cv2.copyMakeBorder(
+ image,
+ pad_h_half,
+ pad_h - pad_h_half,
+ pad_w_half,
+ pad_w - pad_w_half,
+ cv2.BORDER_CONSTANT,
+ value=self.padding)
+
+ if label is not None:
+ # label = cv2.resize(label, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST)
+ label = resize_depth_preserve(label, (reshape_h, reshape_w))
+ label = cv2.copyMakeBorder(
+ label,
+ pad_h_half,
+ pad_h - pad_h_half,
+ pad_w_half,
+ pad_w - pad_w_half,
+ cv2.BORDER_CONSTANT,
+ value=self.ignore_label)
+ # scale the label
+ label = label / to_scale_ratio
+
+ # Resize, adjust principle point
+ if intrinsic is not None:
+ intrinsic[0] = intrinsic[0] * resize_ratio / to_scale_ratio
+ intrinsic[1] = intrinsic[1] * resize_ratio / to_scale_ratio
+ intrinsic[2] = intrinsic[2] * resize_ratio
+ intrinsic[3] = intrinsic[3] * resize_ratio
+
+ if cam_model is not None:
+ #cam_model = cv2.resize(cam_model, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR)
+ cam_model = build_camera_model(reshape_h, reshape_w, intrinsic)
+ cam_model = cv2.copyMakeBorder(
+ cam_model,
+ pad_h_half,
+ pad_h - pad_h_half,
+ pad_w_half,
+ pad_w - pad_w_half,
+ cv2.BORDER_CONSTANT,
+ value=self.ignore_label)
+
+ # Pad, adjust the principle point
+ if intrinsic is not None:
+ intrinsic[2] = intrinsic[2] + pad_w_half
+ intrinsic[3] = intrinsic[3] + pad_h_half
+ return image, label, intrinsic, cam_model
+
+ def get_label_scale_factor(self, image, intrinsic, resize_ratio):
+ ori_h, ori_w, _ = image.shape
+ # crop_h, crop_w = self.crop_size
+ ori_focal = intrinsic[0]
+
+ to_canonical_ratio = self.canonical_focal / ori_focal
+ to_scale_ratio = resize_ratio / to_canonical_ratio
+ return to_scale_ratio
+
+ def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None):
+ target_h, target_w, _ = images[0].shape
+ resize_ratio_h = self.resize_h / target_h
+ resize_ratio_w = self.resize_w / target_w
+ resize_ratio = min(resize_ratio_h, resize_ratio_w)
+ reshape_h = int(resize_ratio * target_h)
+ reshape_w = int(resize_ratio * target_w)
+ pad_h = max(self.resize_h - reshape_h, 0)
+ pad_w = max(self.resize_w - reshape_w, 0)
+ pad_h_half = int(pad_h / 2)
+ pad_w_half = int(pad_w / 2)
+
+ pad_info = [pad_h, pad_w, pad_h_half, pad_w_half]
+ to_scale_ratio = self.get_label_scale_factor(images[0], intrinsics[0], resize_ratio)
+
+ for i in range(len(images)):
+ img = images[i]
+ label = labels[i] if i < len(labels) else None
+ intrinsic = intrinsics[i] if i < len(intrinsics) else None
+ cam_model = cam_models[i] if cam_models is not None and i < len(cam_models) else None
+ img, label, intrinsic, cam_model = self.main_data_transform(
+ img, label, intrinsic, cam_model, resize_ratio, pad_info, to_scale_ratio)
+ images[i] = img
+ if label is not None:
+ labels[i] = label
+ if intrinsic is not None:
+ intrinsics[i] = intrinsic
+ if cam_model is not None:
+ cam_models[i] = cam_model
+
+ if other_labels is not None:
+
+ for i, other_lab in enumerate(other_labels):
+ # resize
+ other_lab = cv2.resize(other_lab, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST)
+ # pad
+ other_labels[i] = cv2.copyMakeBorder(
+ other_lab,
+ pad_h_half,
+ pad_h - pad_h_half,
+ pad_w_half,
+ pad_w - pad_w_half,
+ cv2.BORDER_CONSTANT,
+ value=self.ignore_label)
+
+ pad = [pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half]
+ if transform_paras is not None:
+ pad_old = transform_paras['pad'] if 'pad' in transform_paras else [0,0,0,0]
+ new_pad = [pad_old[0] + pad[0], pad_old[1] + pad[1], pad_old[2] + pad[2], pad_old[3] + pad[3]]
+ transform_paras.update(dict(pad=new_pad))
+ if 'label_scale_factor' in transform_paras:
+ transform_paras['label_scale_factor'] = transform_paras['label_scale_factor'] * 1.0 / to_scale_ratio
+ else:
+ transform_paras.update(label_scale_factor=1.0/to_scale_ratio)
+ return images, labels, intrinsics, cam_models, other_labels, transform_paras
+
+
+class BGR2RGB(object):
+ # Converts image from BGR order to RGB order, for model initialized from Pytorch
+ def __init__(self, **kwargs):
+ return
+ def __call__(self, images, labels, intrinsics, cam_models=None,other_labels=None, transform_paras=None):
+ for i, img in enumerate(images):
+ images[i] = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ return images, labels, intrinsics, cam_models, other_labels, transform_paras
+
+
+def resize_depth_preserve(depth, shape):
+ """
+ Resizes depth map preserving all valid depth pixels
+ Multiple downsampled points can be assigned to the same pixel.
+
+ Parameters
+ ----------
+ depth : np.array [h,w]
+ Depth map
+ shape : tuple (H,W)
+ Output shape
+
+ Returns
+ -------
+ depth : np.array [H,W,1]
+ Resized depth map
+ """
+ # Store dimensions and reshapes to single column
+ depth = np.squeeze(depth)
+ h, w = depth.shape
+ x = depth.reshape(-1)
+ # Create coordinate grid
+ uv = np.mgrid[:h, :w].transpose(1, 2, 0).reshape(-1, 2)
+ # Filters valid points
+ idx = x > 0
+ crd, val = uv[idx], x[idx]
+ # Downsamples coordinates
+ crd[:, 0] = (crd[:, 0] * (shape[0] / h) + 0.5).astype(np.int32)
+ crd[:, 1] = (crd[:, 1] * (shape[1] / w) + 0.5).astype(np.int32)
+ # Filters points inside image
+ idx = (crd[:, 0] < shape[0]) & (crd[:, 1] < shape[1])
+ crd, val = crd[idx], val[idx]
+ # Creates downsampled depth image and assigns points
+ depth = np.zeros(shape)
+ depth[crd[:, 0], crd[:, 1]] = val
+ # Return resized depth map
+ return depth
+
+
+def build_camera_model(H : int, W : int, intrinsics : list) -> np.array:
+ """
+ Encode the camera intrinsic parameters (focal length and principle point) to a 4-channel map.
+ """
+ fx, fy, u0, v0 = intrinsics
+ f = (fx + fy) / 2.0
+ # principle point location
+ x_row = np.arange(0, W).astype(np.float32)
+ x_row_center_norm = (x_row - u0) / W
+ x_center = np.tile(x_row_center_norm, (H, 1)) # [H, W]
+
+ y_col = np.arange(0, H).astype(np.float32)
+ y_col_center_norm = (y_col - v0) / H
+ y_center = np.tile(y_col_center_norm, (W, 1)).T
+
+ # FoV
+ fov_x = np.arctan(x_center / (f / W))
+ fov_y = np.arctan(y_center/ (f / H))
+
+ cam_model = np.stack([x_center, y_center, fov_x, fov_y], axis=2)
+ return cam_model
+
+def gray_to_colormap(img, cmap='rainbow'):
+ """
+ Transfer gray map to matplotlib colormap
+ """
+ assert img.ndim == 2
+
+ img[img<0] = 0
+ mask_invalid = img < 1e-10
+ img = img / (img.max() + 1e-8)
+ norm = matplotlib.colors.Normalize(vmin=0, vmax=1.1)
+ cmap_m = matplotlib.cm.get_cmap(cmap)
+ map = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap_m)
+ colormap = (map.to_rgba(img)[:, :, :3] * 255).astype(np.uint8)
+ colormap[mask_invalid] = 0
+ return colormap
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/metric3d/mono/utils/unproj_pcd.py b/src/custom_controlnet_aux/metric3d/mono/utils/unproj_pcd.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0986d482a2ec68be1dd65719adec662272b833c
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/utils/unproj_pcd.py
@@ -0,0 +1,88 @@
+import numpy as np
+import torch
+from plyfile import PlyData, PlyElement
+import cv2
+
+
+def get_pcd_base(H, W, u0, v0, fx, fy):
+ x_row = np.arange(0, W)
+ x = np.tile(x_row, (H, 1))
+ x = x.astype(np.float32)
+ u_m_u0 = x - u0
+
+ y_col = np.arange(0, H) # y_col = np.arange(0, height)
+ y = np.tile(y_col, (W, 1)).T
+ y = y.astype(np.float32)
+ v_m_v0 = y - v0
+
+ x = u_m_u0 / fx
+ y = v_m_v0 / fy
+ z = np.ones_like(x)
+ pw = np.stack([x, y, z], axis=2) # [h, w, c]
+ return pw
+
+
+def reconstruct_pcd(depth, fx, fy, u0, v0, pcd_base=None, mask=None):
+ if type(depth) == torch.__name__:
+ depth = depth.cpu().numpy().squeeze()
+ depth = cv2.medianBlur(depth, 5)
+ if pcd_base is None:
+ H, W = depth.shape
+ pcd_base = get_pcd_base(H, W, u0, v0, fx, fy)
+ pcd = depth[:, :, None] * pcd_base
+ if mask:
+ pcd[mask] = 0
+ return pcd
+
+
+def save_point_cloud(pcd, rgb, filename, binary=True):
+ """Save an RGB point cloud as a PLY file.
+ :paras
+ @pcd: Nx3 matrix, the XYZ coordinates
+ @rgb: Nx3 matrix, the rgb colors for each 3D point
+ """
+ assert pcd.shape[0] == rgb.shape[0]
+
+ if rgb is None:
+ gray_concat = np.tile(np.array([128], dtype=np.uint8),
+ (pcd.shape[0], 3))
+ points_3d = np.hstack((pcd, gray_concat))
+ else:
+ points_3d = np.hstack((pcd, rgb))
+ python_types = (float, float, float, int, int, int)
+ npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'),
+ ('green', 'u1'), ('blue', 'u1')]
+ if binary is True:
+ # Format into Numpy structured array
+ vertices = []
+ for row_idx in range(points_3d.shape[0]):
+ cur_point = points_3d[row_idx]
+ vertices.append(
+ tuple(
+ dtype(point)
+ for dtype, point in zip(python_types, cur_point)))
+ vertices_array = np.array(vertices, dtype=npy_types)
+ el = PlyElement.describe(vertices_array, 'vertex')
+
+ # write
+ PlyData([el]).write(filename)
+ else:
+ x = np.squeeze(points_3d[:, 0])
+ y = np.squeeze(points_3d[:, 1])
+ z = np.squeeze(points_3d[:, 2])
+ r = np.squeeze(points_3d[:, 3])
+ g = np.squeeze(points_3d[:, 4])
+ b = np.squeeze(points_3d[:, 5])
+
+ ply_head = 'ply\n' \
+ 'format ascii 1.0\n' \
+ 'element vertex %d\n' \
+ 'property float x\n' \
+ 'property float y\n' \
+ 'property float z\n' \
+ 'property uchar red\n' \
+ 'property uchar green\n' \
+ 'property uchar blue\n' \
+ 'end_header' % r.shape[0]
+ # ---- Save ply data to disk
+ np.savetxt(filename, np.column_stack[x, y, z, r, g, b], fmt='%f %f %f %d %d %d', header=ply_head, comments='')
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/metric3d/mono/utils/visualization.py b/src/custom_controlnet_aux/metric3d/mono/utils/visualization.py
new file mode 100644
index 0000000000000000000000000000000000000000..85523f802928a0ba2782bd0ab4871745f8597b5e
--- /dev/null
+++ b/src/custom_controlnet_aux/metric3d/mono/utils/visualization.py
@@ -0,0 +1,139 @@
+import matplotlib.pyplot as plt
+import os, cv2
+import numpy as np
+from .transform import gray_to_colormap
+import shutil
+import glob
+from .running import main_process
+import torch
+
+def save_raw_imgs(
+ pred: torch.tensor,
+ rgb: torch.tensor,
+ filename: str,
+ save_dir: str,
+ scale: float=200.0,
+ target: torch.tensor=None,
+ ):
+ """
+ Save raw GT, predictions, RGB in the same file.
+ """
+ cv2.imwrite(os.path.join(save_dir, filename[:-4]+'_rgb.jpg'), rgb)
+ cv2.imwrite(os.path.join(save_dir, filename[:-4]+'_d.png'), (pred*scale).astype(np.uint16))
+ if target is not None:
+ cv2.imwrite(os.path.join(save_dir, filename[:-4]+'_gt.png'), (target*scale).astype(np.uint16))
+
+
+def save_val_imgs(
+ iter: int,
+ pred: torch.tensor,
+ target: torch.tensor,
+ rgb: torch.tensor,
+ filename: str,
+ save_dir: str,
+ tb_logger=None
+ ):
+ """
+ Save GT, predictions, RGB in the same file.
+ """
+ rgb, pred_scale, target_scale, pred_color, target_color = get_data_for_log(pred, target, rgb)
+ rgb = rgb.transpose((1, 2, 0))
+ cat_img = np.concatenate([rgb, pred_color, target_color], axis=0)
+ plt.imsave(os.path.join(save_dir, filename[:-4]+'_merge.jpg'), cat_img)
+
+ # save to tensorboard
+ if tb_logger is not None:
+ tb_logger.add_image(f'{filename[:-4]}_merge.jpg', cat_img.transpose((2, 0, 1)), iter)
+
+def save_normal_val_imgs(
+ iter: int,
+ pred: torch.tensor,
+ targ: torch.tensor,
+ rgb: torch.tensor,
+ filename: str,
+ save_dir: str,
+ tb_logger=None,
+ mask=None,
+ ):
+ """
+ Save GT, predictions, RGB in the same file.
+ """
+ mean = np.array([123.675, 116.28, 103.53])[np.newaxis, np.newaxis, :]
+ std= np.array([58.395, 57.12, 57.375])[np.newaxis, np.newaxis, :]
+ pred = pred.squeeze()
+ targ = targ.squeeze()
+ rgb = rgb.squeeze()
+
+ if pred.size(0) == 3:
+ pred = pred.permute(1,2,0)
+ if targ.size(0) == 3:
+ targ = targ.permute(1,2,0)
+ if rgb.size(0) == 3:
+ rgb = rgb.permute(1,2,0)
+
+ pred_color = vis_surface_normal(pred, mask)
+ targ_color = vis_surface_normal(targ, mask)
+ rgb_color = ((rgb.cpu().numpy() * std) + mean).astype(np.uint8)
+
+ try:
+ cat_img = np.concatenate([rgb_color, pred_color, targ_color], axis=0)
+ except:
+ pred_color = cv2.resize(pred_color, (rgb.shape[1], rgb.shape[0]))
+ targ_color = cv2.resize(targ_color, (rgb.shape[1], rgb.shape[0]))
+ cat_img = np.concatenate([rgb_color, pred_color, targ_color], axis=0)
+
+ plt.imsave(os.path.join(save_dir, filename[:-4]+'_merge.jpg'), cat_img)
+ # cv2.imwrite(os.path.join(save_dir, filename[:-4]+'.jpg'), pred_color)
+ # save to tensorboard
+ if tb_logger is not None:
+ tb_logger.add_image(f'{filename[:-4]}_merge.jpg', cat_img.transpose((2, 0, 1)), iter)
+
+def get_data_for_log(pred: torch.tensor, target: torch.tensor, rgb: torch.tensor):
+ mean = np.array([123.675, 116.28, 103.53])[:, np.newaxis, np.newaxis]
+ std= np.array([58.395, 57.12, 57.375])[:, np.newaxis, np.newaxis]
+
+ pred = pred.squeeze().cpu().numpy()
+ target = target.squeeze().cpu().numpy()
+ rgb = rgb.squeeze().cpu().numpy()
+
+ pred[pred<0] = 0
+ target[target<0] = 0
+ max_scale = max(pred.max(), target.max())
+ pred_scale = (pred/max_scale * 10000).astype(np.uint16)
+ target_scale = (target/max_scale * 10000).astype(np.uint16)
+ pred_color = gray_to_colormap(pred)
+ target_color = gray_to_colormap(target)
+ pred_color = cv2.resize(pred_color, (rgb.shape[2], rgb.shape[1]))
+ target_color = cv2.resize(target_color, (rgb.shape[2], rgb.shape[1]))
+
+ rgb = ((rgb * std) + mean).astype(np.uint8)
+ return rgb, pred_scale, target_scale, pred_color, target_color
+
+
+def create_html(name2path, save_path='index.html', size=(256, 384)):
+ # table description
+ cols = []
+ for k, v in name2path.items():
+ col_i = Col('img', k, v) # specify image content for column
+ cols.append(col_i)
+ # html table generation
+ imagetable(cols, out_file=save_path, imsize=size)
+
+def vis_surface_normal(normal: torch.tensor, mask: torch.tensor=None) -> np.array:
+ """
+ Visualize surface normal. Transfer surface normal value from [-1, 1] to [0, 255]
+ Aargs:
+ normal (torch.tensor, [h, w, 3]): surface normal
+ mask (torch.tensor, [h, w]): valid masks
+ """
+ normal = normal.cpu().numpy().squeeze()
+ n_img_L2 = np.sqrt(np.sum(normal ** 2, axis=2, keepdims=True))
+ n_img_norm = normal / (n_img_L2 + 1e-8)
+ normal_vis = n_img_norm * 127
+ normal_vis += 128
+ normal_vis = normal_vis.astype(np.uint8)
+ if mask is not None:
+ mask = mask.cpu().numpy().squeeze()
+ normal_vis[~mask] = 0
+ return normal_vis
+
diff --git a/src/custom_controlnet_aux/midas/LICENSE b/src/custom_controlnet_aux/midas/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..277b5c11be103f028a8d10985139f1da10c2f08e
--- /dev/null
+++ b/src/custom_controlnet_aux/midas/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab)
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/src/custom_controlnet_aux/midas/__init__.py b/src/custom_controlnet_aux/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c58f390d63ec36b5f10cf863be645098e4394598
--- /dev/null
+++ b/src/custom_controlnet_aux/midas/__init__.py
@@ -0,0 +1,2 @@
+# Modern MiDaS implementation using HuggingFace transformers
+from .transformers import MidasDetector
diff --git a/src/custom_controlnet_aux/midas/transformers.py b/src/custom_controlnet_aux/midas/transformers.py
new file mode 100644
index 0000000000000000000000000000000000000000..53cdac7b7f128232c8b08a9a1e4de0107cb96ce7
--- /dev/null
+++ b/src/custom_controlnet_aux/midas/transformers.py
@@ -0,0 +1,108 @@
+"""
+MiDaS implementation using HuggingFace transformers for PyTorch 2.7 compatibility.
+"""
+import numpy as np
+import torch
+import cv2
+from PIL import Image
+from typing import Union
+
+# Import utilities
+from ..util import HWC3, common_input_validate, resize_image_with_pad
+
+
+class MidasDetector:
+
+ def __init__(self, model_name="Intel/dpt-large"):
+ from transformers import DPTForDepthEstimation, DPTImageProcessor
+
+ self.model_name = model_name
+ self.processor = DPTImageProcessor.from_pretrained(model_name)
+ self.model = DPTForDepthEstimation.from_pretrained(model_name)
+ self.device = "cpu"
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path=None, model_type="dpt_hybrid", filename="dpt_hybrid-midas-501f0c75.pt"):
+ # Map legacy model types to HuggingFace models
+ model_mapping = {
+ "dpt_large": "Intel/dpt-large",
+ "dpt_hybrid": "Intel/dpt-hybrid-midas",
+ "midas_v21": "Intel/dpt-large",
+ "midas_v21_small": "Intel/dpt-large"
+ }
+
+ # Use filename for model selection if provided
+ if filename and isinstance(filename, str):
+ if "dpt_large" in filename.lower():
+ model_name = "Intel/dpt-large"
+ elif "dpt_hybrid" in filename.lower():
+ model_name = "Intel/dpt-hybrid-midas"
+ else:
+ model_name = model_mapping.get(model_type, "Intel/dpt-large")
+ else:
+ model_name = model_mapping.get(model_type, "Intel/dpt-large")
+
+ return cls(model_name)
+
+ def to(self, device):
+ self.model = self.model.to(device)
+ self.device = device
+ return self
+
+ def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1, depth_and_normal=False, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+
+ # Convert to PIL for processor
+ pil_image = Image.fromarray(detected_map.astype(np.uint8))
+
+ # Process with HuggingFace pipeline
+ with torch.no_grad():
+ inputs = self.processor(images=pil_image, return_tensors="pt")
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+ outputs = self.model(**inputs)
+ depth = outputs.predicted_depth
+
+ # Normalize depth
+ depth = torch.nn.functional.interpolate(
+ depth.unsqueeze(1),
+ size=(detected_map.shape[0], detected_map.shape[1]),
+ mode="bicubic",
+ align_corners=False,
+ ).squeeze()
+
+ depth_pt = depth.clone()
+ depth_pt -= torch.min(depth_pt)
+ depth_pt /= torch.max(depth_pt)
+ depth_pt = depth_pt.cpu().numpy()
+ depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
+
+ if depth_and_normal:
+ depth_np = depth.cpu().numpy()
+ x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
+ y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
+ z = np.ones_like(x) * a
+ x[depth_pt < bg_th] = 0
+ y[depth_pt < bg_th] = 0
+ normal = np.stack([x, y, z], axis=2)
+ normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
+ normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)[:, :, ::-1]
+
+ depth_image = HWC3(depth_image)
+ if depth_and_normal:
+ normal_image = HWC3(normal_image)
+
+ depth_image = remove_pad(depth_image)
+ if depth_and_normal:
+ normal_image = remove_pad(normal_image)
+
+ if output_type == "pil":
+ depth_image = Image.fromarray(depth_image)
+ if depth_and_normal:
+ normal_image = Image.fromarray(normal_image)
+
+ if depth_and_normal:
+ return depth_image, normal_image
+ else:
+ return depth_image
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/midas/utils.py b/src/custom_controlnet_aux/midas/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59
--- /dev/null
+++ b/src/custom_controlnet_aux/midas/utils.py
@@ -0,0 +1,189 @@
+"""Utils for monoDepth."""
+import sys
+import re
+import numpy as np
+import cv2
+import torch
+
+
+def read_pfm(path):
+ """Read pfm file.
+
+ Args:
+ path (str): path to file
+
+ Returns:
+ tuple: (data, scale)
+ """
+ with open(path, "rb") as file:
+
+ color = None
+ width = None
+ height = None
+ scale = None
+ endian = None
+
+ header = file.readline().rstrip()
+ if header.decode("ascii") == "PF":
+ color = True
+ elif header.decode("ascii") == "Pf":
+ color = False
+ else:
+ raise Exception("Not a PFM file: " + path)
+
+ dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
+ if dim_match:
+ width, height = list(map(int, dim_match.groups()))
+ else:
+ raise Exception("Malformed PFM header.")
+
+ scale = float(file.readline().decode("ascii").rstrip())
+ if scale < 0:
+ # little-endian
+ endian = "<"
+ scale = -scale
+ else:
+ # big-endian
+ endian = ">"
+
+ data = np.fromfile(file, endian + "f")
+ shape = (height, width, 3) if color else (height, width)
+
+ data = np.reshape(data, shape)
+ data = np.flipud(data)
+
+ return data, scale
+
+
+def write_pfm(path, image, scale=1):
+ """Write pfm file.
+
+ Args:
+ path (str): pathto file
+ image (array): data
+ scale (int, optional): Scale. Defaults to 1.
+ """
+
+ with open(path, "wb") as file:
+ color = None
+
+ if image.dtype.name != "float32":
+ raise Exception("Image dtype must be float32.")
+
+ image = np.flipud(image)
+
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
+ color = True
+ elif (
+ len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
+ ): # greyscale
+ color = False
+ else:
+ raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
+
+ file.write("PF\n" if color else "Pf\n".encode())
+ file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
+
+ endian = image.dtype.byteorder
+
+ if endian == "<" or endian == "=" and sys.byteorder == "little":
+ scale = -scale
+
+ file.write("%f\n".encode() % scale)
+
+ image.tofile(file)
+
+
+def read_image(path):
+ """Read image and output RGB image (0-1).
+
+ Args:
+ path (str): path to file
+
+ Returns:
+ array: RGB image (0-1)
+ """
+ img = cv2.imread(path)
+
+ if img.ndim == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
+
+ return img
+
+
+def resize_image(img):
+ """Resize image and make it fit for network.
+
+ Args:
+ img (array): image
+
+ Returns:
+ tensor: data ready for network
+ """
+ height_orig = img.shape[0]
+ width_orig = img.shape[1]
+
+ if width_orig > height_orig:
+ scale = width_orig / 384
+ else:
+ scale = height_orig / 384
+
+ height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
+ width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
+
+ img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
+
+ img_resized = (
+ torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
+ )
+ img_resized = img_resized.unsqueeze(0)
+
+ return img_resized
+
+
+def resize_depth(depth, width, height):
+ """Resize depth map and bring to CPU (numpy).
+
+ Args:
+ depth (tensor): depth
+ width (int): image width
+ height (int): image height
+
+ Returns:
+ array: processed depth
+ """
+ depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
+
+ depth_resized = cv2.resize(
+ depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
+ )
+
+ return depth_resized
+
+def write_depth(path, depth, bits=1):
+ """Write depth map to pfm and png file.
+
+ Args:
+ path (str): filepath without extension
+ depth (array): depth
+ """
+ write_pfm(path + ".pfm", depth.astype(np.float32))
+
+ depth_min = depth.min()
+ depth_max = depth.max()
+
+ max_val = (2**(8*bits))-1
+
+ if depth_max - depth_min > np.finfo("float").eps:
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
+ else:
+ out = np.zeros(depth.shape, dtype=depth.type)
+
+ if bits == 1:
+ cv2.imwrite(path + ".png", out.astype("uint8"))
+ elif bits == 2:
+ cv2.imwrite(path + ".png", out.astype("uint16"))
+
+ return
diff --git a/src/custom_controlnet_aux/mlsd/LICENSE b/src/custom_controlnet_aux/mlsd/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..d855c6db44b4e873eedd750d34fa2eaf22e22363
--- /dev/null
+++ b/src/custom_controlnet_aux/mlsd/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "{}"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2021-present NAVER Corp.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/mlsd/__init__.py b/src/custom_controlnet_aux/mlsd/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe3b0a972c953894e9e4e60fc62847559d7c8022
--- /dev/null
+++ b/src/custom_controlnet_aux/mlsd/__init__.py
@@ -0,0 +1,51 @@
+import os
+import warnings
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+
+from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, HF_MODEL_NAME
+from .models.mbv2_mlsd_large import MobileV2_MLSD_Large
+from .utils import pred_lines
+
+
+class MLSDdetector:
+ def __init__(self, model):
+ self.model = model
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, filename="mlsd_large_512_fp32.pth"):
+ subfolder = "annotator/ckpts" if pretrained_model_or_path == "lllyasviel/ControlNet" else ''
+ model_path = custom_hf_download(pretrained_model_or_path, filename, subfolder=subfolder)
+ model = MobileV2_MLSD_Large()
+ model.load_state_dict(torch.load(model_path), strict=True)
+ model.eval()
+
+ return cls(model)
+
+ def to(self, device):
+ self.model.to(device)
+ return self
+
+ def __call__(self, input_image, thr_v=0.1, thr_d=0.1, detect_resolution=512, output_type="pil", upscale_method="INTER_AREA", **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+ img = detected_map
+ img_output = np.zeros_like(img)
+ try:
+ with torch.no_grad():
+ lines = pred_lines(img, self.model, [img.shape[0], img.shape[1]], thr_v, thr_d)
+ for line in lines:
+ x_start, y_start, x_end, y_end = [int(val) for val in line]
+ cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1)
+ except Exception as e:
+ pass
+
+ detected_map = remove_pad(HWC3(img_output[:, :, 0]))
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
diff --git a/src/custom_controlnet_aux/mlsd/models/__init__.py b/src/custom_controlnet_aux/mlsd/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/custom_controlnet_aux/mlsd/models/mbv2_mlsd_large.py b/src/custom_controlnet_aux/mlsd/models/mbv2_mlsd_large.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b9799e7573ca41549b3c3b13ac47b906b369603
--- /dev/null
+++ b/src/custom_controlnet_aux/mlsd/models/mbv2_mlsd_large.py
@@ -0,0 +1,292 @@
+import os
+import sys
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+from torch.nn import functional as F
+
+
+class BlockTypeA(nn.Module):
+ def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
+ super(BlockTypeA, self).__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(in_c2, out_c2, kernel_size=1),
+ nn.BatchNorm2d(out_c2),
+ nn.ReLU(inplace=True)
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(in_c1, out_c1, kernel_size=1),
+ nn.BatchNorm2d(out_c1),
+ nn.ReLU(inplace=True)
+ )
+ self.upscale = upscale
+
+ def forward(self, a, b):
+ b = self.conv1(b)
+ a = self.conv2(a)
+ if self.upscale:
+ b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
+ return torch.cat((a, b), dim=1)
+
+
+class BlockTypeB(nn.Module):
+ def __init__(self, in_c, out_c):
+ super(BlockTypeB, self).__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
+ nn.BatchNorm2d(in_c),
+ nn.ReLU()
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
+ nn.BatchNorm2d(out_c),
+ nn.ReLU()
+ )
+
+ def forward(self, x):
+ x = self.conv1(x) + x
+ x = self.conv2(x)
+ return x
+
+class BlockTypeC(nn.Module):
+ def __init__(self, in_c, out_c):
+ super(BlockTypeC, self).__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
+ nn.BatchNorm2d(in_c),
+ nn.ReLU()
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
+ nn.BatchNorm2d(in_c),
+ nn.ReLU()
+ )
+ self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.conv2(x)
+ x = self.conv3(x)
+ return x
+
+def _make_divisible(v, divisor, min_value=None):
+ """
+ This function is taken from the original tf repo.
+ It ensures that all layers have a channel number that is divisible by 8
+ It can be seen here:
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+ :param v:
+ :param divisor:
+ :param min_value:
+ :return:
+ """
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
+
+class ConvBNReLU(nn.Sequential):
+ def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+ self.channel_pad = out_planes - in_planes
+ self.stride = stride
+ #padding = (kernel_size - 1) // 2
+
+ # TFLite uses slightly different padding than PyTorch
+ if stride == 2:
+ padding = 0
+ else:
+ padding = (kernel_size - 1) // 2
+
+ super(ConvBNReLU, self).__init__(
+ nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
+ nn.BatchNorm2d(out_planes),
+ nn.ReLU6(inplace=True)
+ )
+ self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
+
+
+ def forward(self, x):
+ # TFLite uses different padding
+ if self.stride == 2:
+ x = F.pad(x, (0, 1, 0, 1), "constant", 0)
+ #print(x.shape)
+
+ for module in self:
+ if not isinstance(module, nn.MaxPool2d):
+ x = module(x)
+ return x
+
+
+class InvertedResidual(nn.Module):
+ def __init__(self, inp, oup, stride, expand_ratio):
+ super(InvertedResidual, self).__init__()
+ self.stride = stride
+ assert stride in [1, 2]
+
+ hidden_dim = int(round(inp * expand_ratio))
+ self.use_res_connect = self.stride == 1 and inp == oup
+
+ layers = []
+ if expand_ratio != 1:
+ # pw
+ layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
+ layers.extend([
+ # dw
+ ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
+ # pw-linear
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(oup),
+ ])
+ self.conv = nn.Sequential(*layers)
+
+ def forward(self, x):
+ if self.use_res_connect:
+ return x + self.conv(x)
+ else:
+ return self.conv(x)
+
+
+class MobileNetV2(nn.Module):
+ def __init__(self, pretrained=True):
+ """
+ MobileNet V2 main class
+ Args:
+ num_classes (int): Number of classes
+ width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
+ inverted_residual_setting: Network structure
+ round_nearest (int): Round the number of channels in each layer to be a multiple of this number
+ Set to 1 to turn off rounding
+ block: Module specifying inverted residual building block for mobilenet
+ """
+ super(MobileNetV2, self).__init__()
+
+ block = InvertedResidual
+ input_channel = 32
+ last_channel = 1280
+ width_mult = 1.0
+ round_nearest = 8
+
+ inverted_residual_setting = [
+ # t, c, n, s
+ [1, 16, 1, 1],
+ [6, 24, 2, 2],
+ [6, 32, 3, 2],
+ [6, 64, 4, 2],
+ [6, 96, 3, 1],
+ #[6, 160, 3, 2],
+ #[6, 320, 1, 1],
+ ]
+
+ # only check the first element, assuming user knows t,c,n,s are required
+ if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
+ raise ValueError("inverted_residual_setting should be non-empty "
+ "or a 4-element list, got {}".format(inverted_residual_setting))
+
+ # building first layer
+ input_channel = _make_divisible(input_channel * width_mult, round_nearest)
+ self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
+ features = [ConvBNReLU(4, input_channel, stride=2)]
+ # building inverted residual blocks
+ for t, c, n, s in inverted_residual_setting:
+ output_channel = _make_divisible(c * width_mult, round_nearest)
+ for i in range(n):
+ stride = s if i == 0 else 1
+ features.append(block(input_channel, output_channel, stride, expand_ratio=t))
+ input_channel = output_channel
+
+ self.features = nn.Sequential(*features)
+ self.fpn_selected = [1, 3, 6, 10, 13]
+ # weight initialization
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out')
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, 0, 0.01)
+ nn.init.zeros_(m.bias)
+ if pretrained:
+ self._load_pretrained_model()
+
+ def _forward_impl(self, x):
+ # This exists since TorchScript doesn't support inheritance, so the superclass method
+ # (this one) needs to have a name other than `forward` that can be accessed in a subclass
+ fpn_features = []
+ for i, f in enumerate(self.features):
+ if i > self.fpn_selected[-1]:
+ break
+ x = f(x)
+ if i in self.fpn_selected:
+ fpn_features.append(x)
+
+ c1, c2, c3, c4, c5 = fpn_features
+ return c1, c2, c3, c4, c5
+
+
+ def forward(self, x):
+ return self._forward_impl(x)
+
+ def _load_pretrained_model(self):
+ pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
+ model_dict = {}
+ state_dict = self.state_dict()
+ for k, v in pretrain_dict.items():
+ if k in state_dict:
+ model_dict[k] = v
+ state_dict.update(model_dict)
+ self.load_state_dict(state_dict)
+
+
+class MobileV2_MLSD_Large(nn.Module):
+ def __init__(self):
+ super(MobileV2_MLSD_Large, self).__init__()
+
+ self.backbone = MobileNetV2(pretrained=False)
+ ## A, B
+ self.block15 = BlockTypeA(in_c1= 64, in_c2= 96,
+ out_c1= 64, out_c2=64,
+ upscale=False)
+ self.block16 = BlockTypeB(128, 64)
+
+ ## A, B
+ self.block17 = BlockTypeA(in_c1 = 32, in_c2 = 64,
+ out_c1= 64, out_c2= 64)
+ self.block18 = BlockTypeB(128, 64)
+
+ ## A, B
+ self.block19 = BlockTypeA(in_c1=24, in_c2=64,
+ out_c1=64, out_c2=64)
+ self.block20 = BlockTypeB(128, 64)
+
+ ## A, B, C
+ self.block21 = BlockTypeA(in_c1=16, in_c2=64,
+ out_c1=64, out_c2=64)
+ self.block22 = BlockTypeB(128, 64)
+
+ self.block23 = BlockTypeC(64, 16)
+
+ def forward(self, x):
+ c1, c2, c3, c4, c5 = self.backbone(x)
+
+ x = self.block15(c4, c5)
+ x = self.block16(x)
+
+ x = self.block17(c3, x)
+ x = self.block18(x)
+
+ x = self.block19(c2, x)
+ x = self.block20(x)
+
+ x = self.block21(c1, x)
+ x = self.block22(x)
+ x = self.block23(x)
+ x = x[:, 7:, :, :]
+
+ return x
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/mlsd/models/mbv2_mlsd_tiny.py b/src/custom_controlnet_aux/mlsd/models/mbv2_mlsd_tiny.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3ed633f2cc23ea1829a627fdb879ab39f641f83
--- /dev/null
+++ b/src/custom_controlnet_aux/mlsd/models/mbv2_mlsd_tiny.py
@@ -0,0 +1,275 @@
+import os
+import sys
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+from torch.nn import functional as F
+
+
+class BlockTypeA(nn.Module):
+ def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
+ super(BlockTypeA, self).__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(in_c2, out_c2, kernel_size=1),
+ nn.BatchNorm2d(out_c2),
+ nn.ReLU(inplace=True)
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(in_c1, out_c1, kernel_size=1),
+ nn.BatchNorm2d(out_c1),
+ nn.ReLU(inplace=True)
+ )
+ self.upscale = upscale
+
+ def forward(self, a, b):
+ b = self.conv1(b)
+ a = self.conv2(a)
+ b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
+ return torch.cat((a, b), dim=1)
+
+
+class BlockTypeB(nn.Module):
+ def __init__(self, in_c, out_c):
+ super(BlockTypeB, self).__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
+ nn.BatchNorm2d(in_c),
+ nn.ReLU()
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
+ nn.BatchNorm2d(out_c),
+ nn.ReLU()
+ )
+
+ def forward(self, x):
+ x = self.conv1(x) + x
+ x = self.conv2(x)
+ return x
+
+class BlockTypeC(nn.Module):
+ def __init__(self, in_c, out_c):
+ super(BlockTypeC, self).__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
+ nn.BatchNorm2d(in_c),
+ nn.ReLU()
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
+ nn.BatchNorm2d(in_c),
+ nn.ReLU()
+ )
+ self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.conv2(x)
+ x = self.conv3(x)
+ return x
+
+def _make_divisible(v, divisor, min_value=None):
+ """
+ This function is taken from the original tf repo.
+ It ensures that all layers have a channel number that is divisible by 8
+ It can be seen here:
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+ :param v:
+ :param divisor:
+ :param min_value:
+ :return:
+ """
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
+
+class ConvBNReLU(nn.Sequential):
+ def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+ self.channel_pad = out_planes - in_planes
+ self.stride = stride
+ #padding = (kernel_size - 1) // 2
+
+ # TFLite uses slightly different padding than PyTorch
+ if stride == 2:
+ padding = 0
+ else:
+ padding = (kernel_size - 1) // 2
+
+ super(ConvBNReLU, self).__init__(
+ nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
+ nn.BatchNorm2d(out_planes),
+ nn.ReLU6(inplace=True)
+ )
+ self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
+
+
+ def forward(self, x):
+ # TFLite uses different padding
+ if self.stride == 2:
+ x = F.pad(x, (0, 1, 0, 1), "constant", 0)
+ #print(x.shape)
+
+ for module in self:
+ if not isinstance(module, nn.MaxPool2d):
+ x = module(x)
+ return x
+
+
+class InvertedResidual(nn.Module):
+ def __init__(self, inp, oup, stride, expand_ratio):
+ super(InvertedResidual, self).__init__()
+ self.stride = stride
+ assert stride in [1, 2]
+
+ hidden_dim = int(round(inp * expand_ratio))
+ self.use_res_connect = self.stride == 1 and inp == oup
+
+ layers = []
+ if expand_ratio != 1:
+ # pw
+ layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
+ layers.extend([
+ # dw
+ ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
+ # pw-linear
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(oup),
+ ])
+ self.conv = nn.Sequential(*layers)
+
+ def forward(self, x):
+ if self.use_res_connect:
+ return x + self.conv(x)
+ else:
+ return self.conv(x)
+
+
+class MobileNetV2(nn.Module):
+ def __init__(self, pretrained=True):
+ """
+ MobileNet V2 main class
+ Args:
+ num_classes (int): Number of classes
+ width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
+ inverted_residual_setting: Network structure
+ round_nearest (int): Round the number of channels in each layer to be a multiple of this number
+ Set to 1 to turn off rounding
+ block: Module specifying inverted residual building block for mobilenet
+ """
+ super(MobileNetV2, self).__init__()
+
+ block = InvertedResidual
+ input_channel = 32
+ last_channel = 1280
+ width_mult = 1.0
+ round_nearest = 8
+
+ inverted_residual_setting = [
+ # t, c, n, s
+ [1, 16, 1, 1],
+ [6, 24, 2, 2],
+ [6, 32, 3, 2],
+ [6, 64, 4, 2],
+ #[6, 96, 3, 1],
+ #[6, 160, 3, 2],
+ #[6, 320, 1, 1],
+ ]
+
+ # only check the first element, assuming user knows t,c,n,s are required
+ if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
+ raise ValueError("inverted_residual_setting should be non-empty "
+ "or a 4-element list, got {}".format(inverted_residual_setting))
+
+ # building first layer
+ input_channel = _make_divisible(input_channel * width_mult, round_nearest)
+ self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
+ features = [ConvBNReLU(4, input_channel, stride=2)]
+ # building inverted residual blocks
+ for t, c, n, s in inverted_residual_setting:
+ output_channel = _make_divisible(c * width_mult, round_nearest)
+ for i in range(n):
+ stride = s if i == 0 else 1
+ features.append(block(input_channel, output_channel, stride, expand_ratio=t))
+ input_channel = output_channel
+ self.features = nn.Sequential(*features)
+
+ self.fpn_selected = [3, 6, 10]
+ # weight initialization
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out')
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, 0, 0.01)
+ nn.init.zeros_(m.bias)
+
+ #if pretrained:
+ # self._load_pretrained_model()
+
+ def _forward_impl(self, x):
+ # This exists since TorchScript doesn't support inheritance, so the superclass method
+ # (this one) needs to have a name other than `forward` that can be accessed in a subclass
+ fpn_features = []
+ for i, f in enumerate(self.features):
+ if i > self.fpn_selected[-1]:
+ break
+ x = f(x)
+ if i in self.fpn_selected:
+ fpn_features.append(x)
+
+ c2, c3, c4 = fpn_features
+ return c2, c3, c4
+
+
+ def forward(self, x):
+ return self._forward_impl(x)
+
+ def _load_pretrained_model(self):
+ pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
+ model_dict = {}
+ state_dict = self.state_dict()
+ for k, v in pretrain_dict.items():
+ if k in state_dict:
+ model_dict[k] = v
+ state_dict.update(model_dict)
+ self.load_state_dict(state_dict)
+
+
+class MobileV2_MLSD_Tiny(nn.Module):
+ def __init__(self):
+ super(MobileV2_MLSD_Tiny, self).__init__()
+
+ self.backbone = MobileNetV2(pretrained=True)
+
+ self.block12 = BlockTypeA(in_c1= 32, in_c2= 64,
+ out_c1= 64, out_c2=64)
+ self.block13 = BlockTypeB(128, 64)
+
+ self.block14 = BlockTypeA(in_c1 = 24, in_c2 = 64,
+ out_c1= 32, out_c2= 32)
+ self.block15 = BlockTypeB(64, 64)
+
+ self.block16 = BlockTypeC(64, 16)
+
+ def forward(self, x):
+ c2, c3, c4 = self.backbone(x)
+
+ x = self.block12(c3, c4)
+ x = self.block13(x)
+ x = self.block14(c2, x)
+ x = self.block15(x)
+ x = self.block16(x)
+ x = x[:, 7:, :, :]
+ #print(x.shape)
+ x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True)
+
+ return x
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/mlsd/utils.py b/src/custom_controlnet_aux/mlsd/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..28071cbf129a2bedb21a44f95d565aef7974e583
--- /dev/null
+++ b/src/custom_controlnet_aux/mlsd/utils.py
@@ -0,0 +1,584 @@
+'''
+modified by lihaoweicv
+pytorch version
+'''
+
+'''
+M-LSD
+Copyright 2021-present NAVER Corp.
+Apache License v2.0
+'''
+
+import os
+import numpy as np
+import cv2
+import torch
+from torch.nn import functional as F
+
+
+def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5):
+ '''
+ tpMap:
+ center: tpMap[1, 0, :, :]
+ displacement: tpMap[1, 1:5, :, :]
+ '''
+ b, c, h, w = tpMap.shape
+ assert b==1, 'only support bsize==1'
+ displacement = tpMap[:, 1:5, :, :][0]
+ center = tpMap[:, 0, :, :]
+ heat = torch.sigmoid(center)
+ hmax = F.max_pool2d( heat, (ksize, ksize), stride=1, padding=(ksize-1)//2)
+ keep = (hmax == heat).float()
+ heat = heat * keep
+ heat = heat.reshape(-1, )
+
+ scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True)
+ yy = torch.floor_divide(indices, w).unsqueeze(-1)
+ xx = torch.fmod(indices, w).unsqueeze(-1)
+ ptss = torch.cat((yy, xx),dim=-1)
+
+ ptss = ptss.detach().cpu().numpy()
+ scores = scores.detach().cpu().numpy()
+ displacement = displacement.detach().cpu().numpy()
+ displacement = displacement.transpose((1,2,0))
+ return ptss, scores, displacement
+
+
+def pred_lines(image, model,
+ input_shape=[512, 512],
+ score_thr=0.10,
+ dist_thr=20.0):
+ h, w, _ = image.shape
+
+ device = next(iter(model.parameters())).device
+ h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]]
+
+ resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA),
+ np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
+
+ resized_image = resized_image.transpose((2,0,1))
+ batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
+ batch_image = (batch_image / 127.5) - 1.0
+
+ batch_image = torch.from_numpy(batch_image).float()
+ batch_image = batch_image.to(device)
+ outputs = model(batch_image)
+ pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
+ start = vmap[:, :, :2]
+ end = vmap[:, :, 2:]
+ dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
+
+ segments_list = []
+ for center, score in zip(pts, pts_score):
+ y, x = center
+ distance = dist_map[y, x]
+ if score > score_thr and distance > dist_thr:
+ disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
+ x_start = x + disp_x_start
+ y_start = y + disp_y_start
+ x_end = x + disp_x_end
+ y_end = y + disp_y_end
+ segments_list.append([x_start, y_start, x_end, y_end])
+
+ lines = 2 * np.array(segments_list) # 256 > 512
+ lines[:, 0] = lines[:, 0] * w_ratio
+ lines[:, 1] = lines[:, 1] * h_ratio
+ lines[:, 2] = lines[:, 2] * w_ratio
+ lines[:, 3] = lines[:, 3] * h_ratio
+
+ return lines
+
+
+def pred_squares(image,
+ model,
+ input_shape=[512, 512],
+ params={'score': 0.06,
+ 'outside_ratio': 0.28,
+ 'inside_ratio': 0.45,
+ 'w_overlap': 0.0,
+ 'w_degree': 1.95,
+ 'w_length': 0.0,
+ 'w_area': 1.86,
+ 'w_center': 0.14}):
+ '''
+ shape = [height, width]
+ '''
+ h, w, _ = image.shape
+ original_shape = [h, w]
+ device = next(iter(model.parameters())).device
+
+ resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA),
+ np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
+ resized_image = resized_image.transpose((2, 0, 1))
+ batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
+ batch_image = (batch_image / 127.5) - 1.0
+
+ batch_image = torch.from_numpy(batch_image).float().to(device)
+ outputs = model(batch_image)
+
+ pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
+ start = vmap[:, :, :2] # (x, y)
+ end = vmap[:, :, 2:] # (x, y)
+ dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
+
+ junc_list = []
+ segments_list = []
+ for junc, score in zip(pts, pts_score):
+ y, x = junc
+ distance = dist_map[y, x]
+ if score > params['score'] and distance > 20.0:
+ junc_list.append([x, y])
+ disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
+ d_arrow = 1.0
+ x_start = x + d_arrow * disp_x_start
+ y_start = y + d_arrow * disp_y_start
+ x_end = x + d_arrow * disp_x_end
+ y_end = y + d_arrow * disp_y_end
+ segments_list.append([x_start, y_start, x_end, y_end])
+
+ segments = np.array(segments_list)
+
+ ####### post processing for squares
+ # 1. get unique lines
+ point = np.array([[0, 0]])
+ point = point[0]
+ start = segments[:, :2]
+ end = segments[:, 2:]
+ diff = start - end
+ a = diff[:, 1]
+ b = -diff[:, 0]
+ c = a * start[:, 0] + b * start[:, 1]
+
+ d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a ** 2 + b ** 2 + 1e-10)
+ theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi
+ theta[theta < 0.0] += 180
+ hough = np.concatenate([d[:, None], theta[:, None]], axis=-1)
+
+ d_quant = 1
+ theta_quant = 2
+ hough[:, 0] //= d_quant
+ hough[:, 1] //= theta_quant
+ _, indices, counts = np.unique(hough, axis=0, return_index=True, return_counts=True)
+
+ acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='float32')
+ idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='int32') - 1
+ yx_indices = hough[indices, :].astype('int32')
+ acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts
+ idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices
+
+ acc_map_np = acc_map
+ # acc_map = acc_map[None, :, :, None]
+ #
+ # ### fast suppression using tensorflow op
+ # acc_map = tf.constant(acc_map, dtype=tf.float32)
+ # max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map)
+ # acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32)
+ # flatten_acc_map = tf.reshape(acc_map, [1, -1])
+ # topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts))
+ # _, h, w, _ = acc_map.shape
+ # y = tf.expand_dims(topk_indices // w, axis=-1)
+ # x = tf.expand_dims(topk_indices % w, axis=-1)
+ # yx = tf.concat([y, x], axis=-1)
+
+ ### fast suppression using pytorch op
+ acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0)
+ _,_, h, w = acc_map.shape
+ max_acc_map = F.max_pool2d(acc_map,kernel_size=5, stride=1, padding=2)
+ acc_map = acc_map * ( (acc_map == max_acc_map).float() )
+ flatten_acc_map = acc_map.reshape([-1, ])
+
+ scores, indices = torch.topk(flatten_acc_map, len(pts), dim=-1, largest=True)
+ yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1)
+ xx = torch.fmod(indices, w).unsqueeze(-1)
+ yx = torch.cat((yy, xx), dim=-1)
+
+ yx = yx.detach().cpu().numpy()
+
+ topk_values = scores.detach().cpu().numpy()
+ indices = idx_map[yx[:, 0], yx[:, 1]]
+ basis = 5 // 2
+
+ merged_segments = []
+ for yx_pt, max_indice, value in zip(yx, indices, topk_values):
+ y, x = yx_pt
+ if max_indice == -1 or value == 0:
+ continue
+ segment_list = []
+ for y_offset in range(-basis, basis + 1):
+ for x_offset in range(-basis, basis + 1):
+ indice = idx_map[y + y_offset, x + x_offset]
+ cnt = int(acc_map_np[y + y_offset, x + x_offset])
+ if indice != -1:
+ segment_list.append(segments[indice])
+ if cnt > 1:
+ check_cnt = 1
+ current_hough = hough[indice]
+ for new_indice, new_hough in enumerate(hough):
+ if (current_hough == new_hough).all() and indice != new_indice:
+ segment_list.append(segments[new_indice])
+ check_cnt += 1
+ if check_cnt == cnt:
+ break
+ group_segments = np.array(segment_list).reshape([-1, 2])
+ sorted_group_segments = np.sort(group_segments, axis=0)
+ x_min, y_min = sorted_group_segments[0, :]
+ x_max, y_max = sorted_group_segments[-1, :]
+
+ deg = theta[max_indice]
+ if deg >= 90:
+ merged_segments.append([x_min, y_max, x_max, y_min])
+ else:
+ merged_segments.append([x_min, y_min, x_max, y_max])
+
+ # 2. get intersections
+ new_segments = np.array(merged_segments) # (x1, y1, x2, y2)
+ start = new_segments[:, :2] # (x1, y1)
+ end = new_segments[:, 2:] # (x2, y2)
+ new_centers = (start + end) / 2.0
+ diff = start - end
+ dist_segments = np.sqrt(np.sum(diff ** 2, axis=-1))
+
+ # ax + by = c
+ a = diff[:, 1]
+ b = -diff[:, 0]
+ c = a * start[:, 0] + b * start[:, 1]
+ pre_det = a[:, None] * b[None, :]
+ det = pre_det - np.transpose(pre_det)
+
+ pre_inter_y = a[:, None] * c[None, :]
+ inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10)
+ pre_inter_x = c[:, None] * b[None, :]
+ inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10)
+ inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]], axis=-1).astype('int32')
+
+ # 3. get corner information
+ # 3.1 get distance
+ '''
+ dist_segments:
+ | dist(0), dist(1), dist(2), ...|
+ dist_inter_to_segment1:
+ | dist(inter,0), dist(inter,0), dist(inter,0), ... |
+ | dist(inter,1), dist(inter,1), dist(inter,1), ... |
+ ...
+ dist_inter_to_semgnet2:
+ | dist(inter,0), dist(inter,1), dist(inter,2), ... |
+ | dist(inter,0), dist(inter,1), dist(inter,2), ... |
+ ...
+ '''
+
+ dist_inter_to_segment1_start = np.sqrt(
+ np.sum(((inter_pts - start[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
+ dist_inter_to_segment1_end = np.sqrt(
+ np.sum(((inter_pts - end[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
+ dist_inter_to_segment2_start = np.sqrt(
+ np.sum(((inter_pts - start[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
+ dist_inter_to_segment2_end = np.sqrt(
+ np.sum(((inter_pts - end[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
+
+ # sort ascending
+ dist_inter_to_segment1 = np.sort(
+ np.concatenate([dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1),
+ axis=-1) # [n_batch, n_batch, 2]
+ dist_inter_to_segment2 = np.sort(
+ np.concatenate([dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1),
+ axis=-1) # [n_batch, n_batch, 2]
+
+ # 3.2 get degree
+ inter_to_start = new_centers[:, None, :] - inter_pts
+ deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1], inter_to_start[:, :, 0]) * 180 / np.pi
+ deg_inter_to_start[deg_inter_to_start < 0.0] += 360
+ inter_to_end = new_centers[None, :, :] - inter_pts
+ deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1], inter_to_end[:, :, 0]) * 180 / np.pi
+ deg_inter_to_end[deg_inter_to_end < 0.0] += 360
+
+ '''
+ B -- G
+ | |
+ C -- R
+ B : blue / G: green / C: cyan / R: red
+
+ 0 -- 1
+ | |
+ 3 -- 2
+ '''
+ # rename variables
+ deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end
+ # sort deg ascending
+ deg_sort = np.sort(np.concatenate([deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1), axis=-1)
+
+ deg_diff_map = np.abs(deg1_map - deg2_map)
+ # we only consider the smallest degree of intersect
+ deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180]
+
+ # define available degree range
+ deg_range = [60, 120]
+
+ corner_dict = {corner_info: [] for corner_info in range(4)}
+ inter_points = []
+ for i in range(inter_pts.shape[0]):
+ for j in range(i + 1, inter_pts.shape[1]):
+ # i, j > line index, always i < j
+ x, y = inter_pts[i, j, :]
+ deg1, deg2 = deg_sort[i, j, :]
+ deg_diff = deg_diff_map[i, j]
+
+ check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1]
+
+ outside_ratio = params['outside_ratio'] # over ratio >>> drop it!
+ inside_ratio = params['inside_ratio'] # over ratio >>> drop it!
+ check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and \
+ dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or \
+ (dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and \
+ dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \
+ ((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and \
+ dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or \
+ (dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and \
+ dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio))
+
+ if check_degree and check_distance:
+ corner_info = None
+
+ if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \
+ (deg2 >= 315 and deg1 >= 45 and deg1 <= 120):
+ corner_info, color_info = 0, 'blue'
+ elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125 and deg2 <= 225):
+ corner_info, color_info = 1, 'green'
+ elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225 and deg2 <= 315):
+ corner_info, color_info = 2, 'black'
+ elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \
+ (deg2 >= 315 and deg1 >= 225 and deg1 <= 315):
+ corner_info, color_info = 3, 'cyan'
+ else:
+ corner_info, color_info = 4, 'red' # we don't use it
+ continue
+
+ corner_dict[corner_info].append([x, y, i, j])
+ inter_points.append([x, y])
+
+ square_list = []
+ connect_list = []
+ segments_list = []
+ for corner0 in corner_dict[0]:
+ for corner1 in corner_dict[1]:
+ connect01 = False
+ for corner0_line in corner0[2:]:
+ if corner0_line in corner1[2:]:
+ connect01 = True
+ break
+ if connect01:
+ for corner2 in corner_dict[2]:
+ connect12 = False
+ for corner1_line in corner1[2:]:
+ if corner1_line in corner2[2:]:
+ connect12 = True
+ break
+ if connect12:
+ for corner3 in corner_dict[3]:
+ connect23 = False
+ for corner2_line in corner2[2:]:
+ if corner2_line in corner3[2:]:
+ connect23 = True
+ break
+ if connect23:
+ for corner3_line in corner3[2:]:
+ if corner3_line in corner0[2:]:
+ # SQUARE!!!
+ '''
+ 0 -- 1
+ | |
+ 3 -- 2
+ square_list:
+ order: 0 > 1 > 2 > 3
+ | x0, y0, x1, y1, x2, y2, x3, y3 |
+ | x0, y0, x1, y1, x2, y2, x3, y3 |
+ ...
+ connect_list:
+ order: 01 > 12 > 23 > 30
+ | line_idx01, line_idx12, line_idx23, line_idx30 |
+ | line_idx01, line_idx12, line_idx23, line_idx30 |
+ ...
+ segments_list:
+ order: 0 > 1 > 2 > 3
+ | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
+ | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
+ ...
+ '''
+ square_list.append(corner0[:2] + corner1[:2] + corner2[:2] + corner3[:2])
+ connect_list.append([corner0_line, corner1_line, corner2_line, corner3_line])
+ segments_list.append(corner0[2:] + corner1[2:] + corner2[2:] + corner3[2:])
+
+ def check_outside_inside(segments_info, connect_idx):
+ # return 'outside or inside', min distance, cover_param, peri_param
+ if connect_idx == segments_info[0]:
+ check_dist_mat = dist_inter_to_segment1
+ else:
+ check_dist_mat = dist_inter_to_segment2
+
+ i, j = segments_info
+ min_dist, max_dist = check_dist_mat[i, j, :]
+ connect_dist = dist_segments[connect_idx]
+ if max_dist > connect_dist:
+ return 'outside', min_dist, 0, 1
+ else:
+ return 'inside', min_dist, -1, -1
+
+ top_square = None
+
+ try:
+ map_size = input_shape[0] / 2
+ squares = np.array(square_list).reshape([-1, 4, 2])
+ score_array = []
+ connect_array = np.array(connect_list)
+ segments_array = np.array(segments_list).reshape([-1, 4, 2])
+
+ # get degree of corners:
+ squares_rollup = np.roll(squares, 1, axis=1)
+ squares_rolldown = np.roll(squares, -1, axis=1)
+ vec1 = squares_rollup - squares
+ normalized_vec1 = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10)
+ vec2 = squares_rolldown - squares
+ normalized_vec2 = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10)
+ inner_products = np.sum(normalized_vec1 * normalized_vec2, axis=-1) # [n_squares, 4]
+ squares_degree = np.arccos(inner_products) * 180 / np.pi # [n_squares, 4]
+
+ # get square score
+ overlap_scores = []
+ degree_scores = []
+ length_scores = []
+
+ for connects, segments, square, degree in zip(connect_array, segments_array, squares, squares_degree):
+ '''
+ 0 -- 1
+ | |
+ 3 -- 2
+
+ # segments: [4, 2]
+ # connects: [4]
+ '''
+
+ ###################################### OVERLAP SCORES
+ cover = 0
+ perimeter = 0
+ # check 0 > 1 > 2 > 3
+ square_length = []
+
+ for start_idx in range(4):
+ end_idx = (start_idx + 1) % 4
+
+ connect_idx = connects[start_idx] # segment idx of segment01
+ start_segments = segments[start_idx]
+ end_segments = segments[end_idx]
+
+ start_point = square[start_idx]
+ end_point = square[end_idx]
+
+ # check whether outside or inside
+ start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(start_segments,
+ connect_idx)
+ end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(end_segments, connect_idx)
+
+ cover += dist_segments[connect_idx] + start_cover_param * start_min + end_cover_param * end_min
+ perimeter += dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min
+
+ square_length.append(
+ dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min)
+
+ overlap_scores.append(cover / perimeter)
+ ######################################
+ ###################################### DEGREE SCORES
+ '''
+ deg0 vs deg2
+ deg1 vs deg3
+ '''
+ deg0, deg1, deg2, deg3 = degree
+ deg_ratio1 = deg0 / deg2
+ if deg_ratio1 > 1.0:
+ deg_ratio1 = 1 / deg_ratio1
+ deg_ratio2 = deg1 / deg3
+ if deg_ratio2 > 1.0:
+ deg_ratio2 = 1 / deg_ratio2
+ degree_scores.append((deg_ratio1 + deg_ratio2) / 2)
+ ######################################
+ ###################################### LENGTH SCORES
+ '''
+ len0 vs len2
+ len1 vs len3
+ '''
+ len0, len1, len2, len3 = square_length
+ len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0
+ len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1
+ length_scores.append((len_ratio1 + len_ratio2) / 2)
+
+ ######################################
+
+ overlap_scores = np.array(overlap_scores)
+ overlap_scores /= np.max(overlap_scores)
+
+ degree_scores = np.array(degree_scores)
+ # degree_scores /= np.max(degree_scores)
+
+ length_scores = np.array(length_scores)
+
+ ###################################### AREA SCORES
+ area_scores = np.reshape(squares, [-1, 4, 2])
+ area_x = area_scores[:, :, 0]
+ area_y = area_scores[:, :, 1]
+ correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:, 0]
+ area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(area_y[:, :-1] * area_x[:, 1:], axis=-1)
+ area_scores = 0.5 * np.abs(area_scores + correction)
+ area_scores /= (map_size * map_size) # np.max(area_scores)
+ ######################################
+
+ ###################################### CENTER SCORES
+ centers = np.array([[256 // 2, 256 // 2]], dtype='float32') # [1, 2]
+ # squares: [n, 4, 2]
+ square_centers = np.mean(squares, axis=1) # [n, 2]
+ center2center = np.sqrt(np.sum((centers - square_centers) ** 2))
+ center_scores = center2center / (map_size / np.sqrt(2.0))
+
+ '''
+ score_w = [overlap, degree, area, center, length]
+ '''
+ score_w = [0.0, 1.0, 10.0, 0.5, 1.0]
+ score_array = params['w_overlap'] * overlap_scores \
+ + params['w_degree'] * degree_scores \
+ + params['w_area'] * area_scores \
+ - params['w_center'] * center_scores \
+ + params['w_length'] * length_scores
+
+ best_square = []
+
+ sorted_idx = np.argsort(score_array)[::-1]
+ score_array = score_array[sorted_idx]
+ squares = squares[sorted_idx]
+
+ except Exception as e:
+ pass
+
+ '''return list
+ merged_lines, squares, scores
+ '''
+
+ try:
+ new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[1] * original_shape[1]
+ new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[0] * original_shape[0]
+ new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[1] * original_shape[1]
+ new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[0] * original_shape[0]
+ except:
+ new_segments = []
+
+ try:
+ squares[:, :, 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1]
+ squares[:, :, 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0]
+ except:
+ squares = []
+ score_array = []
+
+ try:
+ inter_points = np.array(inter_points)
+ inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[1] * original_shape[1]
+ inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[0] * original_shape[0]
+ except:
+ inter_points = []
+
+ return new_segments, squares, score_array, inter_points
diff --git a/src/custom_controlnet_aux/normalbae/LICENSE b/src/custom_controlnet_aux/normalbae/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..16a9d56a3d4c15e4f34ac5426459c58487b01520
--- /dev/null
+++ b/src/custom_controlnet_aux/normalbae/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Caroline Chan
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/normalbae/__init__.py b/src/custom_controlnet_aux/normalbae/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae2f54abbdad2a75e4870fff21346ad2265a53ad
--- /dev/null
+++ b/src/custom_controlnet_aux/normalbae/__init__.py
@@ -0,0 +1,85 @@
+import os
+import types
+import warnings
+
+import cv2
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from einops import rearrange
+from PIL import Image
+
+from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, HF_MODEL_NAME
+from .nets.NNET import NNET
+
+
+# load model
+def load_checkpoint(fpath, model):
+ ckpt = torch.load(fpath, map_location='cpu')['model']
+
+ load_dict = {}
+ for k, v in ckpt.items():
+ if k.startswith('module.'):
+ k_ = k.replace('module.', '')
+ load_dict[k_] = v
+ else:
+ load_dict[k] = v
+
+ model.load_state_dict(load_dict)
+ return model
+
+class NormalBaeDetector:
+ def __init__(self, model):
+ self.model = model
+ self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ self.device = "cpu"
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, filename="scannet.pt"):
+ model_path = custom_hf_download(pretrained_model_or_path, filename)
+
+ args = types.SimpleNamespace()
+ args.mode = 'client'
+ args.architecture = 'BN'
+ args.pretrained = 'scannet'
+ args.sampling_ratio = 0.4
+ args.importance_ratio = 0.7
+ model = NNET(args)
+ model = load_checkpoint(model_path, model)
+ model.eval()
+
+ return cls(model)
+
+ def to(self, device):
+ self.model.to(device)
+ self.device = device
+ return self
+
+
+ def __call__(self, input_image, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+ image_normal = detected_map
+ with torch.no_grad():
+ image_normal = torch.from_numpy(image_normal).float().to(self.device)
+ image_normal = image_normal / 255.0
+ image_normal = rearrange(image_normal, 'h w c -> 1 c h w')
+ image_normal = self.norm(image_normal)
+
+ normal = self.model(image_normal)
+ normal = normal[0][-1][:, :3]
+ # d = torch.sum(normal ** 2.0, dim=1, keepdim=True) ** 0.5
+ # d = torch.maximum(d, torch.ones_like(d) * 1e-5)
+ # normal /= d
+ normal = ((normal + 1) * 0.5).clip(0, 1)
+
+ normal = rearrange(normal[0], 'c h w -> h w c').cpu().numpy()
+ normal_image = (normal * 255.0).clip(0, 255).astype(np.uint8)
+
+ detected_map = remove_pad(HWC3(normal_image))
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
+
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/normalbae/nets/NNET.py b/src/custom_controlnet_aux/normalbae/nets/NNET.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ddbc50c3ac18aa4b7f16779fe3c0133981ecc7a
--- /dev/null
+++ b/src/custom_controlnet_aux/normalbae/nets/NNET.py
@@ -0,0 +1,22 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .submodules.encoder import Encoder
+from .submodules.decoder import Decoder
+
+
+class NNET(nn.Module):
+ def __init__(self, args):
+ super(NNET, self).__init__()
+ self.encoder = Encoder()
+ self.decoder = Decoder(args)
+
+ def get_1x_lr_params(self): # lr/10 learning rate
+ return self.encoder.parameters()
+
+ def get_10x_lr_params(self): # lr learning rate
+ return self.decoder.parameters()
+
+ def forward(self, img, **kwargs):
+ return self.decoder(self.encoder(img), **kwargs)
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/normalbae/nets/__init__.py b/src/custom_controlnet_aux/normalbae/nets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/custom_controlnet_aux/normalbae/nets/submodules/__init__.py b/src/custom_controlnet_aux/normalbae/nets/submodules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/custom_controlnet_aux/normalbae/nets/submodules/decoder.py b/src/custom_controlnet_aux/normalbae/nets/submodules/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef3a192df6da9c356829a71be3358084bfd3f9ae
--- /dev/null
+++ b/src/custom_controlnet_aux/normalbae/nets/submodules/decoder.py
@@ -0,0 +1,202 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .submodules import UpSampleBN, UpSampleGN, norm_normalize, sample_points
+
+
+class Decoder(nn.Module):
+ def __init__(self, args):
+ super(Decoder, self).__init__()
+
+ # hyper-parameter for sampling
+ self.sampling_ratio = args.sampling_ratio
+ self.importance_ratio = args.importance_ratio
+
+ # feature-map
+ self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0)
+ if args.architecture == 'BN':
+ self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024)
+ self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512)
+ self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256)
+ self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128)
+
+ elif args.architecture == 'GN':
+ self.up1 = UpSampleGN(skip_input=2048 + 176, output_features=1024)
+ self.up2 = UpSampleGN(skip_input=1024 + 64, output_features=512)
+ self.up3 = UpSampleGN(skip_input=512 + 40, output_features=256)
+ self.up4 = UpSampleGN(skip_input=256 + 24, output_features=128)
+
+ else:
+ raise Exception('invalid architecture')
+
+ # produces 1/8 res output
+ self.out_conv_res8 = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
+
+ # produces 1/4 res output
+ self.out_conv_res4 = nn.Sequential(
+ nn.Conv1d(512 + 4, 128, kernel_size=1), nn.ReLU(),
+ nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
+ nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
+ nn.Conv1d(128, 4, kernel_size=1),
+ )
+
+ # produces 1/2 res output
+ self.out_conv_res2 = nn.Sequential(
+ nn.Conv1d(256 + 4, 128, kernel_size=1), nn.ReLU(),
+ nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
+ nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
+ nn.Conv1d(128, 4, kernel_size=1),
+ )
+
+ # produces 1/1 res output
+ self.out_conv_res1 = nn.Sequential(
+ nn.Conv1d(128 + 4, 128, kernel_size=1), nn.ReLU(),
+ nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
+ nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
+ nn.Conv1d(128, 4, kernel_size=1),
+ )
+
+ def forward(self, features, gt_norm_mask=None, mode='test'):
+ x_block0, x_block1, x_block2, x_block3, x_block4 = features[3], features[4], features[5], features[7], features[10]
+
+ # generate feature-map
+
+ x_d0 = self.conv2(x_block4) # x_d0 : [2, 2048, 15, 20] 1/32 res
+ x_d1 = self.up1(x_d0, x_block3) # x_d1 : [2, 1024, 30, 40] 1/16 res
+ x_d2 = self.up2(x_d1, x_block2) # x_d2 : [2, 512, 60, 80] 1/8 res
+ x_d3 = self.up3(x_d2, x_block1) # x_d3: [2, 256, 120, 160] 1/4 res
+ x_d4 = self.up4(x_d3, x_block0) # x_d4: [2, 128, 240, 320] 1/2 res
+
+ # 1/8 res output
+ out_res8 = self.out_conv_res8(x_d2) # out_res8: [2, 4, 60, 80] 1/8 res output
+ out_res8 = norm_normalize(out_res8) # out_res8: [2, 4, 60, 80] 1/8 res output
+
+ ################################################################################################################
+ # out_res4
+ ################################################################################################################
+
+ if mode == 'train':
+ # upsampling ... out_res8: [2, 4, 60, 80] -> out_res8_res4: [2, 4, 120, 160]
+ out_res8_res4 = F.interpolate(out_res8, scale_factor=2, mode='bilinear', align_corners=True)
+ B, _, H, W = out_res8_res4.shape
+
+ # samples: [B, 1, N, 2]
+ point_coords_res4, rows_int, cols_int = sample_points(out_res8_res4.detach(), gt_norm_mask,
+ sampling_ratio=self.sampling_ratio,
+ beta=self.importance_ratio)
+
+ # output (needed for evaluation / visualization)
+ out_res4 = out_res8_res4
+
+ # grid_sample feature-map
+ feat_res4 = F.grid_sample(x_d2, point_coords_res4, mode='bilinear', align_corners=True) # (B, 512, 1, N)
+ init_pred = F.grid_sample(out_res8, point_coords_res4, mode='bilinear', align_corners=True) # (B, 4, 1, N)
+ feat_res4 = torch.cat([feat_res4, init_pred], dim=1) # (B, 512+4, 1, N)
+
+ # prediction (needed to compute loss)
+ samples_pred_res4 = self.out_conv_res4(feat_res4[:, :, 0, :]) # (B, 4, N)
+ samples_pred_res4 = norm_normalize(samples_pred_res4) # (B, 4, N) - normalized
+
+ for i in range(B):
+ out_res4[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res4[i, :, :]
+
+ else:
+ # grid_sample feature-map
+ feat_map = F.interpolate(x_d2, scale_factor=2, mode='bilinear', align_corners=True)
+ init_pred = F.interpolate(out_res8, scale_factor=2, mode='bilinear', align_corners=True)
+ feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W)
+ B, _, H, W = feat_map.shape
+
+ # try all pixels
+ out_res4 = self.out_conv_res4(feat_map.view(B, 512 + 4, -1)) # (B, 4, N)
+ out_res4 = norm_normalize(out_res4) # (B, 4, N) - normalized
+ out_res4 = out_res4.view(B, 4, H, W)
+ samples_pred_res4 = point_coords_res4 = None
+
+ ################################################################################################################
+ # out_res2
+ ################################################################################################################
+
+ if mode == 'train':
+
+ # upsampling ... out_res4: [2, 4, 120, 160] -> out_res4_res2: [2, 4, 240, 320]
+ out_res4_res2 = F.interpolate(out_res4, scale_factor=2, mode='bilinear', align_corners=True)
+ B, _, H, W = out_res4_res2.shape
+
+ # samples: [B, 1, N, 2]
+ point_coords_res2, rows_int, cols_int = sample_points(out_res4_res2.detach(), gt_norm_mask,
+ sampling_ratio=self.sampling_ratio,
+ beta=self.importance_ratio)
+
+ # output (needed for evaluation / visualization)
+ out_res2 = out_res4_res2
+
+ # grid_sample feature-map
+ feat_res2 = F.grid_sample(x_d3, point_coords_res2, mode='bilinear', align_corners=True) # (B, 256, 1, N)
+ init_pred = F.grid_sample(out_res4, point_coords_res2, mode='bilinear', align_corners=True) # (B, 4, 1, N)
+ feat_res2 = torch.cat([feat_res2, init_pred], dim=1) # (B, 256+4, 1, N)
+
+ # prediction (needed to compute loss)
+ samples_pred_res2 = self.out_conv_res2(feat_res2[:, :, 0, :]) # (B, 4, N)
+ samples_pred_res2 = norm_normalize(samples_pred_res2) # (B, 4, N) - normalized
+
+ for i in range(B):
+ out_res2[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res2[i, :, :]
+
+ else:
+ # grid_sample feature-map
+ feat_map = F.interpolate(x_d3, scale_factor=2, mode='bilinear', align_corners=True)
+ init_pred = F.interpolate(out_res4, scale_factor=2, mode='bilinear', align_corners=True)
+ feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W)
+ B, _, H, W = feat_map.shape
+
+ out_res2 = self.out_conv_res2(feat_map.view(B, 256 + 4, -1)) # (B, 4, N)
+ out_res2 = norm_normalize(out_res2) # (B, 4, N) - normalized
+ out_res2 = out_res2.view(B, 4, H, W)
+ samples_pred_res2 = point_coords_res2 = None
+
+ ################################################################################################################
+ # out_res1
+ ################################################################################################################
+
+ if mode == 'train':
+ # upsampling ... out_res4: [2, 4, 120, 160] -> out_res4_res2: [2, 4, 240, 320]
+ out_res2_res1 = F.interpolate(out_res2, scale_factor=2, mode='bilinear', align_corners=True)
+ B, _, H, W = out_res2_res1.shape
+
+ # samples: [B, 1, N, 2]
+ point_coords_res1, rows_int, cols_int = sample_points(out_res2_res1.detach(), gt_norm_mask,
+ sampling_ratio=self.sampling_ratio,
+ beta=self.importance_ratio)
+
+ # output (needed for evaluation / visualization)
+ out_res1 = out_res2_res1
+
+ # grid_sample feature-map
+ feat_res1 = F.grid_sample(x_d4, point_coords_res1, mode='bilinear', align_corners=True) # (B, 128, 1, N)
+ init_pred = F.grid_sample(out_res2, point_coords_res1, mode='bilinear', align_corners=True) # (B, 4, 1, N)
+ feat_res1 = torch.cat([feat_res1, init_pred], dim=1) # (B, 128+4, 1, N)
+
+ # prediction (needed to compute loss)
+ samples_pred_res1 = self.out_conv_res1(feat_res1[:, :, 0, :]) # (B, 4, N)
+ samples_pred_res1 = norm_normalize(samples_pred_res1) # (B, 4, N) - normalized
+
+ for i in range(B):
+ out_res1[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res1[i, :, :]
+
+ else:
+ # grid_sample feature-map
+ feat_map = F.interpolate(x_d4, scale_factor=2, mode='bilinear', align_corners=True)
+ init_pred = F.interpolate(out_res2, scale_factor=2, mode='bilinear', align_corners=True)
+ feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W)
+ B, _, H, W = feat_map.shape
+
+ out_res1 = self.out_conv_res1(feat_map.view(B, 128 + 4, -1)) # (B, 4, N)
+ out_res1 = norm_normalize(out_res1) # (B, 4, N) - normalized
+ out_res1 = out_res1.view(B, 4, H, W)
+ samples_pred_res1 = point_coords_res1 = None
+
+ return [out_res8, out_res4, out_res2, out_res1], \
+ [out_res8, samples_pred_res4, samples_pred_res2, samples_pred_res1], \
+ [None, point_coords_res4, point_coords_res2, point_coords_res1]
+
diff --git a/src/custom_controlnet_aux/normalbae/nets/submodules/encoder.py b/src/custom_controlnet_aux/normalbae/nets/submodules/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..51371bc51e2ef8f0abb55a644e792ac269021563
--- /dev/null
+++ b/src/custom_controlnet_aux/normalbae/nets/submodules/encoder.py
@@ -0,0 +1,30 @@
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Encoder(nn.Module):
+ def __init__(self):
+ super(Encoder, self).__init__()
+
+ basemodel_name = 'tf_efficientnet_b5.ap_in1k'
+ print('Loading base model ()...'.format(basemodel_name), end='')
+ import timm
+ basemodel = timm.create_model(basemodel_name, pretrained=False, num_classes=0)
+ print('Done.')
+
+
+ self.original_model = basemodel
+
+ def forward(self, x):
+ features = [x]
+ for k, v in self.original_model._modules.items():
+ if (k == 'blocks'):
+ for ki, vi in v._modules.items():
+ features.append(vi(features[-1]))
+ else:
+ features.append(v(features[-1]))
+ return features
+
+
diff --git a/src/custom_controlnet_aux/normalbae/nets/submodules/submodules.py b/src/custom_controlnet_aux/normalbae/nets/submodules/submodules.py
new file mode 100644
index 0000000000000000000000000000000000000000..409733351bd6ab5d191c800aff1bc05bfa4cb6f8
--- /dev/null
+++ b/src/custom_controlnet_aux/normalbae/nets/submodules/submodules.py
@@ -0,0 +1,140 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+########################################################################################################################
+
+
+# Upsample + BatchNorm
+class UpSampleBN(nn.Module):
+ def __init__(self, skip_input, output_features):
+ super(UpSampleBN, self).__init__()
+
+ self._net = nn.Sequential(nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2d(output_features),
+ nn.LeakyReLU(),
+ nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2d(output_features),
+ nn.LeakyReLU())
+
+ def forward(self, x, concat_with):
+ up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True)
+ f = torch.cat([up_x, concat_with], dim=1)
+ return self._net(f)
+
+
+# Upsample + GroupNorm + Weight Standardization
+class UpSampleGN(nn.Module):
+ def __init__(self, skip_input, output_features):
+ super(UpSampleGN, self).__init__()
+
+ self._net = nn.Sequential(Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
+ nn.GroupNorm(8, output_features),
+ nn.LeakyReLU(),
+ Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1),
+ nn.GroupNorm(8, output_features),
+ nn.LeakyReLU())
+
+ def forward(self, x, concat_with):
+ up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True)
+ f = torch.cat([up_x, concat_with], dim=1)
+ return self._net(f)
+
+
+# Conv2d with weight standardization
+class Conv2d(nn.Conv2d):
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+ padding=0, dilation=1, groups=1, bias=True):
+ super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias)
+
+ def forward(self, x):
+ weight = self.weight
+ weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
+ keepdim=True).mean(dim=3, keepdim=True)
+ weight = weight - weight_mean
+ std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
+ weight = weight / std.expand_as(weight)
+ return F.conv2d(x, weight, self.bias, self.stride,
+ self.padding, self.dilation, self.groups)
+
+
+# normalize
+def norm_normalize(norm_out):
+ min_kappa = 0.01
+ norm_x, norm_y, norm_z, kappa = torch.split(norm_out, 1, dim=1)
+ norm = torch.sqrt(norm_x ** 2.0 + norm_y ** 2.0 + norm_z ** 2.0) + 1e-10
+ kappa = F.elu(kappa) + 1.0 + min_kappa
+ final_out = torch.cat([norm_x / norm, norm_y / norm, norm_z / norm, kappa], dim=1)
+ return final_out
+
+
+# uncertainty-guided sampling (only used during training)
+@torch.no_grad()
+def sample_points(init_normal, gt_norm_mask, sampling_ratio, beta):
+ device = init_normal.device
+ B, _, H, W = init_normal.shape
+ N = int(sampling_ratio * H * W)
+ beta = beta
+
+ # uncertainty map
+ uncertainty_map = -1 * init_normal[:, 3, :, :] # B, H, W
+
+ # gt_invalid_mask (B, H, W)
+ if gt_norm_mask is not None:
+ gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest')
+ gt_invalid_mask = gt_invalid_mask[:, 0, :, :] < 0.5
+ uncertainty_map[gt_invalid_mask] = -1e4
+
+ # (B, H*W)
+ _, idx = uncertainty_map.view(B, -1).sort(1, descending=True)
+
+ # importance sampling
+ if int(beta * N) > 0:
+ importance = idx[:, :int(beta * N)] # B, beta*N
+
+ # remaining
+ remaining = idx[:, int(beta * N):] # B, H*W - beta*N
+
+ # coverage
+ num_coverage = N - int(beta * N)
+
+ if num_coverage <= 0:
+ samples = importance
+ else:
+ coverage_list = []
+ for i in range(B):
+ idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N"
+ coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N
+ coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N
+ samples = torch.cat((importance, coverage), dim=1) # B, N
+
+ else:
+ # remaining
+ remaining = idx[:, :] # B, H*W
+
+ # coverage
+ num_coverage = N
+
+ coverage_list = []
+ for i in range(B):
+ idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N"
+ coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N
+ coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N
+ samples = coverage
+
+ # point coordinates
+ rows_int = samples // W # 0 for first row, H-1 for last row
+ rows_float = rows_int / float(H-1) # 0 to 1.0
+ rows_float = (rows_float * 2.0) - 1.0 # -1.0 to 1.0
+
+ cols_int = samples % W # 0 for first column, W-1 for last column
+ cols_float = cols_int / float(W-1) # 0 to 1.0
+ cols_float = (cols_float * 2.0) - 1.0 # -1.0 to 1.0
+
+ point_coords = torch.zeros(B, 1, N, 2)
+ point_coords[:, 0, :, 0] = cols_float # x coord
+ point_coords[:, 0, :, 1] = rows_float # y coord
+ point_coords = point_coords.to(device)
+ return point_coords, rows_int, cols_int
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/oneformer/__init__.py b/src/custom_controlnet_aux/oneformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..764c78a242205f3df60fec8a6ddf39f93ecb695a
--- /dev/null
+++ b/src/custom_controlnet_aux/oneformer/__init__.py
@@ -0,0 +1,4 @@
+"""
+OneFormer implementation using HuggingFace transformers for PyTorch 2.7 compatibility.
+"""
+from .transformers import OneformerSegmentor
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/oneformer/configs/ade20k/Base-ADE20K-UnifiedSegmentation.yaml b/src/custom_controlnet_aux/oneformer/configs/ade20k/Base-ADE20K-UnifiedSegmentation.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..31eab45b878433fc844a13dbdd54f97c936d9b89
--- /dev/null
+++ b/src/custom_controlnet_aux/oneformer/configs/ade20k/Base-ADE20K-UnifiedSegmentation.yaml
@@ -0,0 +1,68 @@
+MODEL:
+ BACKBONE:
+ FREEZE_AT: 0
+ NAME: "build_resnet_backbone"
+ WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
+ PIXEL_STD: [58.395, 57.120, 57.375]
+ RESNETS:
+ DEPTH: 50
+ STEM_TYPE: "basic" # not used
+ STEM_OUT_CHANNELS: 64
+ STRIDE_IN_1X1: False
+ OUT_FEATURES: ["res2", "res3", "res4", "res5"]
+ # NORM: "SyncBN"
+ RES5_MULTI_GRID: [1, 1, 1] # not used
+DATASETS:
+ TRAIN: ("ade20k_panoptic_train",)
+ TEST_PANOPTIC: ("ade20k_panoptic_val",)
+ TEST_INSTANCE: ("ade20k_instance_val",)
+ TEST_SEMANTIC: ("ade20k_sem_seg_val",)
+SOLVER:
+ IMS_PER_BATCH: 16
+ BASE_LR: 0.0001
+ MAX_ITER: 160000
+ WARMUP_FACTOR: 1.0
+ WARMUP_ITERS: 0
+ WEIGHT_DECAY: 0.05
+ OPTIMIZER: "ADAMW"
+ LR_SCHEDULER_NAME: "WarmupPolyLR"
+ BACKBONE_MULTIPLIER: 0.1
+ CLIP_GRADIENTS:
+ ENABLED: True
+ CLIP_TYPE: "full_model"
+ CLIP_VALUE: 0.01
+ NORM_TYPE: 2.0
+ AMP:
+ ENABLED: True
+INPUT:
+ MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 512) for x in range(5, 21)]"]
+ MIN_SIZE_TRAIN_SAMPLING: "choice"
+ MIN_SIZE_TEST: 512
+ MAX_SIZE_TRAIN: 2048
+ MAX_SIZE_TEST: 2048
+ CROP:
+ ENABLED: True
+ TYPE: "absolute"
+ SIZE: (512, 512)
+ SINGLE_CATEGORY_MAX_AREA: 1.0
+ COLOR_AUG_SSD: True
+ SIZE_DIVISIBILITY: 512 # used in dataset mapper
+ FORMAT: "RGB"
+ DATASET_MAPPER_NAME: "oneformer_unified"
+ MAX_SEQ_LEN: 77
+ TASK_SEQ_LEN: 77
+ TASK_PROB:
+ SEMANTIC: 0.33
+ INSTANCE: 0.66
+TEST:
+ EVAL_PERIOD: 5000
+ AUG:
+ ENABLED: False
+ MIN_SIZES: [256, 384, 512, 640, 768, 896]
+ MAX_SIZE: 3584
+ FLIP: True
+DATALOADER:
+ FILTER_EMPTY_ANNOTATIONS: True
+ NUM_WORKERS: 4
+VERSION: 2
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/oneformer/configs/ade20k/oneformer_R50_bs16_160k.yaml b/src/custom_controlnet_aux/oneformer/configs/ade20k/oneformer_R50_bs16_160k.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..770ffc81907f8d7c7520e079b1c46060707254b8
--- /dev/null
+++ b/src/custom_controlnet_aux/oneformer/configs/ade20k/oneformer_R50_bs16_160k.yaml
@@ -0,0 +1,58 @@
+_BASE_: Base-ADE20K-UnifiedSegmentation.yaml
+MODEL:
+ META_ARCHITECTURE: "OneFormer"
+ SEM_SEG_HEAD:
+ NAME: "OneFormerHead"
+ IGNORE_VALUE: 255
+ NUM_CLASSES: 150
+ LOSS_WEIGHT: 1.0
+ CONVS_DIM: 256
+ MASK_DIM: 256
+ NORM: "GN"
+ # pixel decoder
+ PIXEL_DECODER_NAME: "MSDeformAttnPixelDecoder"
+ IN_FEATURES: ["res2", "res3", "res4", "res5"]
+ DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
+ COMMON_STRIDE: 4
+ TRANSFORMER_ENC_LAYERS: 6
+ ONE_FORMER:
+ TRANSFORMER_DECODER_NAME: "ContrastiveMultiScaleMaskedTransformerDecoder"
+ TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
+ DEEP_SUPERVISION: True
+ NO_OBJECT_WEIGHT: 0.1
+ CLASS_WEIGHT: 2.0
+ MASK_WEIGHT: 5.0
+ DICE_WEIGHT: 5.0
+ CONTRASTIVE_WEIGHT: 0.5
+ CONTRASTIVE_TEMPERATURE: 0.07
+ HIDDEN_DIM: 256
+ NUM_OBJECT_QUERIES: 150
+ USE_TASK_NORM: True
+ NHEADS: 8
+ DROPOUT: 0.1
+ DIM_FEEDFORWARD: 2048
+ ENC_LAYERS: 0
+ PRE_NORM: False
+ ENFORCE_INPUT_PROJ: False
+ SIZE_DIVISIBILITY: 32
+ CLASS_DEC_LAYERS: 2
+ DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query
+ TRAIN_NUM_POINTS: 12544
+ OVERSAMPLE_RATIO: 3.0
+ IMPORTANCE_SAMPLE_RATIO: 0.75
+ TEXT_ENCODER:
+ WIDTH: 256
+ CONTEXT_LENGTH: 77
+ NUM_LAYERS: 6
+ VOCAB_SIZE: 49408
+ PROJ_NUM_LAYERS: 2
+ N_CTX: 16
+ TEST:
+ SEMANTIC_ON: True
+ INSTANCE_ON: True
+ PANOPTIC_ON: True
+ OVERLAP_THRESHOLD: 0.8
+ OBJECT_MASK_THRESHOLD: 0.8
+ TASK: "panoptic"
+TEST:
+ DETECTIONS_PER_IMAGE: 150
diff --git a/src/custom_controlnet_aux/oneformer/configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml b/src/custom_controlnet_aux/oneformer/configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..69c44ade144e4504077c0fe04fa8bb3491a679ed
--- /dev/null
+++ b/src/custom_controlnet_aux/oneformer/configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml
@@ -0,0 +1,40 @@
+_BASE_: oneformer_R50_bs16_160k.yaml
+MODEL:
+ BACKBONE:
+ NAME: "D2SwinTransformer"
+ SWIN:
+ EMBED_DIM: 192
+ DEPTHS: [2, 2, 18, 2]
+ NUM_HEADS: [6, 12, 24, 48]
+ WINDOW_SIZE: 12
+ APE: False
+ DROP_PATH_RATE: 0.3
+ PATCH_NORM: True
+ PRETRAIN_IMG_SIZE: 384
+ WEIGHTS: "swin_large_patch4_window12_384_22k.pkl"
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
+ PIXEL_STD: [58.395, 57.120, 57.375]
+ ONE_FORMER:
+ NUM_OBJECT_QUERIES: 250
+INPUT:
+ MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 640) for x in range(5, 21)]"]
+ MIN_SIZE_TRAIN_SAMPLING: "choice"
+ MIN_SIZE_TEST: 640
+ MAX_SIZE_TRAIN: 2560
+ MAX_SIZE_TEST: 2560
+ CROP:
+ ENABLED: True
+ TYPE: "absolute"
+ SIZE: (640, 640)
+ SINGLE_CATEGORY_MAX_AREA: 1.0
+ COLOR_AUG_SSD: True
+ SIZE_DIVISIBILITY: 640 # used in dataset mapper
+ FORMAT: "RGB"
+TEST:
+ DETECTIONS_PER_IMAGE: 250
+ EVAL_PERIOD: 5000
+ AUG:
+ ENABLED: False
+ MIN_SIZES: [320, 480, 640, 800, 960, 1120]
+ MAX_SIZE: 4480
+ FLIP: True
diff --git a/src/custom_controlnet_aux/oneformer/configs/coco/Base-COCO-UnifiedSegmentation.yaml b/src/custom_controlnet_aux/oneformer/configs/coco/Base-COCO-UnifiedSegmentation.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ccd24f348f9bc7d60dcdc4b74d887708e57cb8a8
--- /dev/null
+++ b/src/custom_controlnet_aux/oneformer/configs/coco/Base-COCO-UnifiedSegmentation.yaml
@@ -0,0 +1,54 @@
+MODEL:
+ BACKBONE:
+ FREEZE_AT: 0
+ NAME: "build_resnet_backbone"
+ WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
+ PIXEL_STD: [58.395, 57.120, 57.375]
+ RESNETS:
+ DEPTH: 50
+ STEM_TYPE: "basic" # not used
+ STEM_OUT_CHANNELS: 64
+ STRIDE_IN_1X1: False
+ OUT_FEATURES: ["res2", "res3", "res4", "res5"]
+ # NORM: "SyncBN"
+ RES5_MULTI_GRID: [1, 1, 1] # not used
+DATASETS:
+ TRAIN: ("coco_2017_train_panoptic_with_sem_seg",)
+ TEST_PANOPTIC: ("coco_2017_val_panoptic_with_sem_seg",) # to evaluate instance and semantic performance as well
+ TEST_INSTANCE: ("coco_2017_val",)
+ TEST_SEMANTIC: ("coco_2017_val_panoptic_with_sem_seg",)
+SOLVER:
+ IMS_PER_BATCH: 16
+ BASE_LR: 0.0001
+ STEPS: (327778, 355092)
+ MAX_ITER: 368750
+ WARMUP_FACTOR: 1.0
+ WARMUP_ITERS: 10
+ WEIGHT_DECAY: 0.05
+ OPTIMIZER: "ADAMW"
+ BACKBONE_MULTIPLIER: 0.1
+ CLIP_GRADIENTS:
+ ENABLED: True
+ CLIP_TYPE: "full_model"
+ CLIP_VALUE: 0.01
+ NORM_TYPE: 2.0
+ AMP:
+ ENABLED: True
+INPUT:
+ IMAGE_SIZE: 1024
+ MIN_SCALE: 0.1
+ MAX_SCALE: 2.0
+ FORMAT: "RGB"
+ DATASET_MAPPER_NAME: "coco_unified_lsj"
+ MAX_SEQ_LEN: 77
+ TASK_SEQ_LEN: 77
+ TASK_PROB:
+ SEMANTIC: 0.33
+ INSTANCE: 0.66
+TEST:
+ EVAL_PERIOD: 5000
+DATALOADER:
+ FILTER_EMPTY_ANNOTATIONS: True
+ NUM_WORKERS: 4
+VERSION: 2
diff --git a/src/custom_controlnet_aux/oneformer/configs/coco/oneformer_R50_bs16_50ep.yaml b/src/custom_controlnet_aux/oneformer/configs/coco/oneformer_R50_bs16_50ep.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f768c8fa8b5e4fc1121e65e050053e0d8870cd73
--- /dev/null
+++ b/src/custom_controlnet_aux/oneformer/configs/coco/oneformer_R50_bs16_50ep.yaml
@@ -0,0 +1,59 @@
+_BASE_: Base-COCO-UnifiedSegmentation.yaml
+MODEL:
+ META_ARCHITECTURE: "OneFormer"
+ SEM_SEG_HEAD:
+ NAME: "OneFormerHead"
+ IGNORE_VALUE: 255
+ NUM_CLASSES: 133
+ LOSS_WEIGHT: 1.0
+ CONVS_DIM: 256
+ MASK_DIM: 256
+ NORM: "GN"
+ # pixel decoder
+ PIXEL_DECODER_NAME: "MSDeformAttnPixelDecoder"
+ IN_FEATURES: ["res2", "res3", "res4", "res5"]
+ DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
+ COMMON_STRIDE: 4
+ TRANSFORMER_ENC_LAYERS: 6
+ ONE_FORMER:
+ TRANSFORMER_DECODER_NAME: "ContrastiveMultiScaleMaskedTransformerDecoder"
+ TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
+ DEEP_SUPERVISION: True
+ NO_OBJECT_WEIGHT: 0.1
+ CLASS_WEIGHT: 2.0
+ MASK_WEIGHT: 5.0
+ DICE_WEIGHT: 5.0
+ CONTRASTIVE_WEIGHT: 0.5
+ CONTRASTIVE_TEMPERATURE: 0.07
+ HIDDEN_DIM: 256
+ NUM_OBJECT_QUERIES: 150
+ USE_TASK_NORM: True
+ NHEADS: 8
+ DROPOUT: 0.1
+ DIM_FEEDFORWARD: 2048
+ ENC_LAYERS: 0
+ PRE_NORM: False
+ ENFORCE_INPUT_PROJ: False
+ SIZE_DIVISIBILITY: 32
+ CLASS_DEC_LAYERS: 2
+ DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query
+ TRAIN_NUM_POINTS: 12544
+ OVERSAMPLE_RATIO: 3.0
+ IMPORTANCE_SAMPLE_RATIO: 0.75
+ TEXT_ENCODER:
+ WIDTH: 256
+ CONTEXT_LENGTH: 77
+ NUM_LAYERS: 6
+ VOCAB_SIZE: 49408
+ PROJ_NUM_LAYERS: 2
+ N_CTX: 16
+ TEST:
+ SEMANTIC_ON: True
+ INSTANCE_ON: True
+ PANOPTIC_ON: True
+ DETECTION_ON: False
+ OVERLAP_THRESHOLD: 0.8
+ OBJECT_MASK_THRESHOLD: 0.8
+ TASK: "panoptic"
+TEST:
+ DETECTIONS_PER_IMAGE: 150
diff --git a/src/custom_controlnet_aux/oneformer/configs/coco/oneformer_swin_large_IN21k_384_bs16_100ep.yaml b/src/custom_controlnet_aux/oneformer/configs/coco/oneformer_swin_large_IN21k_384_bs16_100ep.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..faae655317c52d90b9f756417f8b1a1adcbe78f2
--- /dev/null
+++ b/src/custom_controlnet_aux/oneformer/configs/coco/oneformer_swin_large_IN21k_384_bs16_100ep.yaml
@@ -0,0 +1,25 @@
+_BASE_: oneformer_R50_bs16_50ep.yaml
+MODEL:
+ BACKBONE:
+ NAME: "D2SwinTransformer"
+ SWIN:
+ EMBED_DIM: 192
+ DEPTHS: [2, 2, 18, 2]
+ NUM_HEADS: [6, 12, 24, 48]
+ WINDOW_SIZE: 12
+ APE: False
+ DROP_PATH_RATE: 0.3
+ PATCH_NORM: True
+ PRETRAIN_IMG_SIZE: 384
+ WEIGHTS: "swin_large_patch4_window12_384_22k.pkl"
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
+ PIXEL_STD: [58.395, 57.120, 57.375]
+ ONE_FORMER:
+ NUM_OBJECT_QUERIES: 150
+SOLVER:
+ STEPS: (655556, 735184)
+ MAX_ITER: 737500
+ AMP:
+ ENABLED: False
+TEST:
+ DETECTIONS_PER_IMAGE: 150
diff --git a/src/custom_controlnet_aux/oneformer/transformers.py b/src/custom_controlnet_aux/oneformer/transformers.py
new file mode 100644
index 0000000000000000000000000000000000000000..0eef7d04461d614befe7c366dfbeb388707365b6
--- /dev/null
+++ b/src/custom_controlnet_aux/oneformer/transformers.py
@@ -0,0 +1,124 @@
+"""
+OneFormer implementation using HuggingFace transformers for PyTorch 2.7 compatibility.
+Provides equivalent functionality to the original detectron2 implementation.
+"""
+import numpy as np
+import cv2
+import torch
+from PIL import Image
+
+# Import utilities
+from ..util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, HF_MODEL_NAME
+
+
+class OneformerSegmentor:
+ """
+ OneFormer segmentation using HuggingFace transformers implementation.
+
+ Uses equivalent models that are PyTorch 2.7 compatible and actively maintained:
+ - Same architecture (OneFormer with Swin-Large backbone)
+ - Same training datasets (COCO panoptic / ADE20K)
+ - Professional colorized visualization output
+ """
+
+ def __init__(self, model_name):
+ """Initialize OneFormer with HuggingFace transformers implementation."""
+ from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation
+
+ self.model_name = model_name
+ self.processor = OneFormerProcessor.from_pretrained(model_name)
+ self.model = OneFormerForUniversalSegmentation.from_pretrained(model_name)
+ self.device = "cpu"
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, filename="250_16_swin_l_oneformer_ade20k_160k.pth", config_path=None):
+ """Create OneFormer model from pretrained weights."""
+ model_mapping = {
+ "250_16_swin_l_oneformer_ade20k_160k.pth": "shi-labs/oneformer_ade20k_swin_large",
+ "150_16_swin_l_oneformer_coco_100ep.pth": "shi-labs/oneformer_coco_swin_large"
+ }
+
+ if filename in model_mapping:
+ model_name = model_mapping[filename]
+ elif "coco" in filename.lower():
+ model_name = "shi-labs/oneformer_coco_swin_large"
+ else:
+ model_name = "shi-labs/oneformer_ade20k_swin_large"
+
+ return cls(model_name)
+
+ def to(self, device):
+ """Move model to specified device."""
+ self.model = self.model.to(device)
+ self.device = device
+ return self
+
+ def __call__(self, input_image=None, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
+ """Process image for semantic segmentation."""
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+
+ # Convert to PIL for processing
+ if isinstance(input_image, np.ndarray):
+ pil_image = Image.fromarray(input_image)
+ else:
+ pil_image = input_image
+
+ # Process with HuggingFace pipeline
+ semantic_inputs = self.processor(
+ images=pil_image,
+ task_inputs=["semantic"],
+ return_tensors="pt"
+ ).to(self.device)
+
+ with torch.no_grad():
+ outputs = self.model(**semantic_inputs)
+
+ # Post-process results
+ predicted_semantic_map = self.processor.post_process_semantic_segmentation(
+ outputs, target_sizes=[pil_image.size[::-1]]
+ )[0]
+
+ # Convert to colormap using professional color scheme
+ seg_map = predicted_semantic_map.cpu().numpy().astype(np.uint8)
+ detected_map = self._generate_professional_colormap(seg_map)
+ detected_map = remove_pad(HWC3(detected_map))
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
+
+ def _generate_professional_colormap(self, seg_map):
+ """Generate professional colormap for segmentation visualization."""
+ height, width = seg_map.shape
+ color_map = np.zeros((height, width, 3), dtype=np.uint8)
+
+ max_possible_classes = 200
+ colors = self._generate_detectron2_style_palette(max_possible_classes)
+
+ unique_classes = np.unique(seg_map)
+ for class_id in unique_classes:
+ mask = seg_map == class_id
+ color_map[mask] = colors[class_id % len(colors)]
+
+ return color_map
+
+ def _generate_detectron2_style_palette(self, num_classes):
+ """Generate professional color palette with good visual separation."""
+ colors = np.zeros((num_classes, 3), dtype=np.uint8)
+
+ colors[0] = [0, 0, 0] # Background is black
+
+ for i in range(1, num_classes):
+ hue = (i * 137.508) % 360 # Golden angle for good distribution
+ saturation = 0.6 + 0.4 * ((i % 3) / 2) # 60-100% saturation
+ value = 0.7 + 0.3 * ((i % 2)) # 70-100% value
+
+ color_hsv = np.array([[[hue / 2, saturation * 255, value * 255]]], dtype=np.uint8)
+ color_rgb = cv2.cvtColor(color_hsv, cv2.COLOR_HSV2RGB)[0, 0]
+ colors[i] = color_rgb
+
+ return colors
+
+
diff --git a/src/custom_controlnet_aux/open_pose/LICENSE b/src/custom_controlnet_aux/open_pose/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..6f60b76d35fa1012809985780964a5068adce4fd
--- /dev/null
+++ b/src/custom_controlnet_aux/open_pose/LICENSE
@@ -0,0 +1,108 @@
+OPENPOSE: MULTIPERSON KEYPOINT DETECTION
+SOFTWARE LICENSE AGREEMENT
+ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH USE ONLY
+
+BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE.
+
+This is a license agreement ("Agreement") between your academic institution or non-profit organization or self (called "Licensee" or "You" in this Agreement) and Carnegie Mellon University (called "Licensor" in this Agreement). All rights not specifically granted to you in this Agreement are reserved for Licensor.
+
+RESERVATION OF OWNERSHIP AND GRANT OF LICENSE:
+Licensor retains exclusive ownership of any copy of the Software (as defined below) licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive,
+non-transferable license to use the Software for noncommercial research purposes, without the right to sublicense, pursuant to the terms and conditions of this Agreement. As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i).
+
+CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication.
+
+COPYRIGHT: The Software is owned by Licensor and is protected by United
+States copyright laws and applicable international treaties and/or conventions.
+
+PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto.
+
+DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement.
+
+BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies.
+
+USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark “OpenPose", "Carnegie Mellon" or any renditions thereof without the prior written permission of Licensor.
+
+You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software.
+
+ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void.
+
+TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by downloading the Software or by using the Software until terminated as provided below.
+
+The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement.
+
+FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement.
+
+DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS.
+
+SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement.
+
+EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage.
+
+EXPORT REGULATION: Licensee agrees to comply with any and all applicable
+U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control.
+
+SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby.
+
+NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor.
+
+GOVERNING LAW: This Agreement shall be construed and enforced in accordance with the laws of the Commonwealth of Pennsylvania without reference to conflict of laws principles. You consent to the personal jurisdiction of the courts of this County and waive their rights to venue outside of Allegheny County, Pennsylvania.
+
+ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto.
+
+
+
+************************************************************************
+
+THIRD-PARTY SOFTWARE NOTICES AND INFORMATION
+
+This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. We reserves all other rights not expressly granted, whether by implication, estoppel or otherwise.
+
+1. Caffe, version 1.0.0, (https://github.com/BVLC/caffe/)
+
+COPYRIGHT
+
+All contributions by the University of California:
+Copyright (c) 2014-2017 The Regents of the University of California (Regents)
+All rights reserved.
+
+All other contributions:
+Copyright (c) 2014-2017, the respective contributors
+All rights reserved.
+
+Caffe uses a shared copyright model: each contributor holds copyright over
+their contributions to Caffe. The project versioning records all such
+contribution and copyright details. If a contributor wants to further mark
+their specific copyright on a particular contribution, they should indicate
+their copyright solely in the commit message of the change when it is
+committed.
+
+LICENSE
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+CONTRIBUTION AGREEMENT
+
+By contributing to the BVLC/caffe repository through pull-request, comment,
+or otherwise, the contributor releases their content to the
+license and copyright terms herein.
+
+************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION**********
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/open_pose/__init__.py b/src/custom_controlnet_aux/open_pose/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8d0f53e355b9de99d7aed6aeab903a99428a93c
--- /dev/null
+++ b/src/custom_controlnet_aux/open_pose/__init__.py
@@ -0,0 +1,238 @@
+# Openpose
+# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
+# 2nd Edited by https://github.com/Hzzone/pytorch-openpose
+# 3rd Edited by ControlNet
+# 4th Edited by ControlNet (added face and correct hands)
+# 5th Edited by ControlNet (Improved JSON serialization/deserialization, and lots of bug fixs)
+# This preprocessor is licensed by CMU for non-commercial use only.
+
+
+import os
+
+os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
+
+import json
+import warnings
+from typing import Callable, List, NamedTuple, Tuple, Union
+
+import cv2
+import numpy as np
+import torch
+from huggingface_hub import hf_hub_download
+from PIL import Image
+
+from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, HF_MODEL_NAME
+from . import util
+from .body import Body, BodyResult, Keypoint
+from .face import Face
+from .hand import Hand
+
+HandResult = List[Keypoint]
+FaceResult = List[Keypoint]
+
+class PoseResult(NamedTuple):
+ body: BodyResult
+ left_hand: Union[HandResult, None]
+ right_hand: Union[HandResult, None]
+ face: Union[FaceResult, None]
+
+def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, draw_face=True, xinsr_stick_scaling=False):
+ """
+ Draw the detected poses on an empty canvas.
+
+ Args:
+ poses (List[PoseResult]): A list of PoseResult objects containing the detected poses.
+ H (int): The height of the canvas.
+ W (int): The width of the canvas.
+ draw_body (bool, optional): Whether to draw body keypoints. Defaults to True.
+ draw_hand (bool, optional): Whether to draw hand keypoints. Defaults to True.
+ draw_face (bool, optional): Whether to draw face keypoints. Defaults to True.
+
+ Returns:
+ numpy.ndarray: A 3D numpy array representing the canvas with the drawn poses.
+ """
+ canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
+
+ for pose in poses:
+ if draw_body:
+ canvas = util.draw_bodypose(canvas, pose.body.keypoints, xinsr_stick_scaling)
+
+ if draw_hand:
+ canvas = util.draw_handpose(canvas, pose.left_hand)
+ canvas = util.draw_handpose(canvas, pose.right_hand)
+
+ if draw_face:
+ canvas = util.draw_facepose(canvas, pose.face)
+
+ return canvas
+
+def encode_poses_as_dict(poses: List[PoseResult], canvas_height: int, canvas_width: int) -> str:
+ """ Encode the pose as a dict following openpose JSON output format:
+ https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/doc/02_output.md
+ """
+ def compress_keypoints(keypoints: Union[List[Keypoint], None]) -> Union[List[float], None]:
+ if not keypoints:
+ return None
+
+ return [
+ value
+ for keypoint in keypoints
+ for value in (
+ [float(keypoint.x), float(keypoint.y), 1.0]
+ if keypoint is not None
+ else [0.0, 0.0, 0.0]
+ )
+ ]
+
+ return {
+ 'people': [
+ {
+ 'pose_keypoints_2d': compress_keypoints(pose.body.keypoints),
+ "face_keypoints_2d": compress_keypoints(pose.face),
+ "hand_left_keypoints_2d": compress_keypoints(pose.left_hand),
+ "hand_right_keypoints_2d":compress_keypoints(pose.right_hand),
+ }
+ for pose in poses
+ ],
+ 'canvas_height': canvas_height,
+ 'canvas_width': canvas_width,
+ }
+
+class OpenposeDetector:
+ """
+ A class for detecting human poses in images using the Openpose model.
+
+ Attributes:
+ model_dir (str): Path to the directory where the pose models are stored.
+ """
+ def __init__(self, body_estimation, hand_estimation=None, face_estimation=None):
+ self.body_estimation = body_estimation
+ self.hand_estimation = hand_estimation
+ self.face_estimation = face_estimation
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, filename="body_pose_model.pth", hand_filename="hand_pose_model.pth", face_filename="facenet.pth"):
+ if pretrained_model_or_path == "lllyasviel/ControlNet":
+ subfolder = "annotator/ckpts"
+ face_pretrained_model_or_path = "lllyasviel/Annotators"
+
+ else:
+ subfolder = ''
+ face_pretrained_model_or_path = pretrained_model_or_path
+
+ body_model_path = custom_hf_download(pretrained_model_or_path, filename, subfolder=subfolder)
+ hand_model_path = custom_hf_download(pretrained_model_or_path, hand_filename, subfolder=subfolder)
+ face_model_path = custom_hf_download(face_pretrained_model_or_path, face_filename, subfolder=subfolder)
+
+ body_estimation = Body(body_model_path)
+ hand_estimation = Hand(hand_model_path)
+ face_estimation = Face(face_model_path)
+
+ return cls(body_estimation, hand_estimation, face_estimation)
+
+ def to(self, device):
+ self.body_estimation.to(device)
+ self.hand_estimation.to(device)
+ self.face_estimation.to(device)
+ return self
+
+ def detect_hands(self, body: BodyResult, oriImg) -> Tuple[Union[HandResult, None], Union[HandResult, None]]:
+ left_hand = None
+ right_hand = None
+ H, W, _ = oriImg.shape
+ for x, y, w, is_left in util.handDetect(body, oriImg):
+ peaks = self.hand_estimation(oriImg[y:y+w, x:x+w, :]).astype(np.float32)
+ if peaks.ndim == 2 and peaks.shape[1] == 2:
+ peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W)
+ peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H)
+
+ hand_result = [
+ Keypoint(x=peak[0], y=peak[1])
+ for peak in peaks
+ ]
+
+ if is_left:
+ left_hand = hand_result
+ else:
+ right_hand = hand_result
+
+ return left_hand, right_hand
+
+ def detect_face(self, body: BodyResult, oriImg) -> Union[FaceResult, None]:
+ face = util.faceDetect(body, oriImg)
+ if face is None:
+ return None
+
+ x, y, w = face
+ H, W, _ = oriImg.shape
+ heatmaps = self.face_estimation(oriImg[y:y+w, x:x+w, :])
+ peaks = self.face_estimation.compute_peaks_from_heatmaps(heatmaps).astype(np.float32)
+ if peaks.ndim == 2 and peaks.shape[1] == 2:
+ peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W)
+ peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H)
+ return [
+ Keypoint(x=peak[0], y=peak[1])
+ for peak in peaks
+ ]
+
+ return None
+
+ def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[PoseResult]:
+ """
+ Detect poses in the given image.
+ Args:
+ oriImg (numpy.ndarray): The input image for pose detection.
+ include_hand (bool, optional): Whether to include hand detection. Defaults to False.
+ include_face (bool, optional): Whether to include face detection. Defaults to False.
+
+ Returns:
+ List[PoseResult]: A list of PoseResult objects containing the detected poses.
+ """
+ oriImg = oriImg[:, :, ::-1].copy()
+ H, W, C = oriImg.shape
+ with torch.no_grad():
+ candidate, subset = self.body_estimation(oriImg)
+ bodies = self.body_estimation.format_body_result(candidate, subset)
+
+ results = []
+ for body in bodies:
+ left_hand, right_hand, face = (None,) * 3
+ if include_hand:
+ left_hand, right_hand = self.detect_hands(body, oriImg)
+ if include_face:
+ face = self.detect_face(body, oriImg)
+
+ results.append(PoseResult(BodyResult(
+ keypoints=[
+ Keypoint(
+ x=keypoint.x / float(W),
+ y=keypoint.y / float(H)
+ ) if keypoint is not None else None
+ for keypoint in body.keypoints
+ ],
+ total_score=body.total_score,
+ total_parts=body.total_parts
+ ), left_hand, right_hand, face))
+
+ return results
+
+ def __call__(self, input_image, detect_resolution=512, include_body=True, include_hand=False, include_face=False, hand_and_face=None, output_type="pil", image_and_json=False, upscale_method="INTER_CUBIC", xinsr_stick_scaling=False, **kwargs):
+ if hand_and_face is not None:
+ warnings.warn("hand_and_face is deprecated. Use include_hand and include_face instead.", DeprecationWarning)
+ include_hand = hand_and_face
+ include_face = hand_and_face
+
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+
+ poses = self.detect_poses(input_image, include_hand=include_hand, include_face=include_face)
+ canvas = draw_poses(poses, input_image.shape[0], input_image.shape[1], draw_body=include_body, draw_hand=include_hand, draw_face=include_face, xinsr_stick_scaling=xinsr_stick_scaling)
+ detected_map = HWC3(remove_pad(canvas))
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ if image_and_json:
+ return (detected_map, encode_poses_as_dict(poses, detected_map.shape[0], detected_map.shape[1]))
+
+ return detected_map
diff --git a/src/custom_controlnet_aux/open_pose/body.py b/src/custom_controlnet_aux/open_pose/body.py
new file mode 100644
index 0000000000000000000000000000000000000000..edd744869b603c4e544f7a522a375a6ac3794aa4
--- /dev/null
+++ b/src/custom_controlnet_aux/open_pose/body.py
@@ -0,0 +1,278 @@
+import math
+from typing import List, NamedTuple, Union
+
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from scipy.ndimage.filters import gaussian_filter
+
+from . import util
+from .model import bodypose_model
+
+
+class Keypoint(NamedTuple):
+ x: float
+ y: float
+ score: float = 1.0
+ id: int = -1
+
+
+class BodyResult(NamedTuple):
+ # Note: Using `Union` instead of `|` operator as the ladder is a Python
+ # 3.10 feature.
+ # Annotator code should be Python 3.8 Compatible, as controlnet repo uses
+ # Python 3.8 environment.
+ # https://github.com/lllyasviel/ControlNet/blob/d3284fcd0972c510635a4f5abe2eeb71dc0de524/environment.yaml#L6
+ keypoints: List[Union[Keypoint, None]]
+ total_score: float
+ total_parts: int
+
+
+class Body(object):
+ def __init__(self, model_path):
+ self.model = bodypose_model()
+ model_dict = util.transfer(self.model, torch.load(model_path))
+ self.model.load_state_dict(model_dict)
+ self.model.eval()
+ self.device = "cpu"
+
+ def to(self, device):
+ self.model.to(device)
+ self.device = device
+ return self
+
+ def __call__(self, oriImg):
+ # scale_search = [0.5, 1.0, 1.5, 2.0]
+ scale_search = [0.5]
+ boxsize = 368
+ stride = 8
+ padValue = 128
+ thre1 = 0.1
+ thre2 = 0.05
+ multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
+ heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
+ paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
+
+ for m in range(len(multiplier)):
+ scale = multiplier[m]
+ imageToTest = util.smart_resize_k(oriImg, fx=scale, fy=scale)
+ imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
+ im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
+ im = np.ascontiguousarray(im)
+
+ data = torch.from_numpy(im).float()
+ data = data.to(self.device)
+ # data = data.permute([2, 0, 1]).unsqueeze(0).float()
+ with torch.no_grad():
+ Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
+ Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
+ Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()
+
+ # extract outputs, resize, and remove padding
+ # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0)) # output 1 is heatmaps
+ heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0)) # output 1 is heatmaps
+ heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride)
+ heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+ heatmap = util.smart_resize(heatmap, (oriImg.shape[0], oriImg.shape[1]))
+
+ # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs
+ paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0)) # output 0 is PAFs
+ paf = util.smart_resize_k(paf, fx=stride, fy=stride)
+ paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+ paf = util.smart_resize(paf, (oriImg.shape[0], oriImg.shape[1]))
+
+ heatmap_avg += heatmap_avg + heatmap / len(multiplier)
+ paf_avg += + paf / len(multiplier)
+
+ all_peaks = []
+ peak_counter = 0
+
+ for part in range(18):
+ map_ori = heatmap_avg[:, :, part]
+ one_heatmap = gaussian_filter(map_ori, sigma=3)
+
+ map_left = np.zeros(one_heatmap.shape)
+ map_left[1:, :] = one_heatmap[:-1, :]
+ map_right = np.zeros(one_heatmap.shape)
+ map_right[:-1, :] = one_heatmap[1:, :]
+ map_up = np.zeros(one_heatmap.shape)
+ map_up[:, 1:] = one_heatmap[:, :-1]
+ map_down = np.zeros(one_heatmap.shape)
+ map_down[:, :-1] = one_heatmap[:, 1:]
+
+ peaks_binary = np.logical_and.reduce(
+ (one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1))
+ peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])) # note reverse
+ peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
+ peak_id = range(peak_counter, peak_counter + len(peaks))
+ peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))]
+
+ all_peaks.append(peaks_with_score_and_id)
+ peak_counter += len(peaks)
+
+ # find connection in the specified sequence, center 29 is in the position 15
+ limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
+ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
+ [1, 16], [16, 18], [3, 17], [6, 18]]
+ # the middle joints heatmap correpondence
+ mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \
+ [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \
+ [55, 56], [37, 38], [45, 46]]
+
+ connection_all = []
+ special_k = []
+ mid_num = 10
+
+ for k in range(len(mapIdx)):
+ score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
+ candA = all_peaks[limbSeq[k][0] - 1]
+ candB = all_peaks[limbSeq[k][1] - 1]
+ nA = len(candA)
+ nB = len(candB)
+ indexA, indexB = limbSeq[k]
+ if (nA != 0 and nB != 0):
+ connection_candidate = []
+ for i in range(nA):
+ for j in range(nB):
+ vec = np.subtract(candB[j][:2], candA[i][:2])
+ norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
+ norm = max(0.001, norm)
+ vec = np.divide(vec, norm)
+
+ startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \
+ np.linspace(candA[i][1], candB[j][1], num=mid_num)))
+
+ vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \
+ for I in range(len(startend))])
+ vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \
+ for I in range(len(startend))])
+
+ score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
+ score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(
+ 0.5 * oriImg.shape[0] / norm - 1, 0)
+ criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
+ criterion2 = score_with_dist_prior > 0
+ if criterion1 and criterion2:
+ connection_candidate.append(
+ [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]])
+
+ connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
+ connection = np.zeros((0, 5))
+ for c in range(len(connection_candidate)):
+ i, j, s = connection_candidate[c][0:3]
+ if (i not in connection[:, 3] and j not in connection[:, 4]):
+ connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
+ if (len(connection) >= min(nA, nB)):
+ break
+
+ connection_all.append(connection)
+ else:
+ special_k.append(k)
+ connection_all.append([])
+
+ # last number in each row is the total parts number of that person
+ # the second last number in each row is the score of the overall configuration
+ subset = -1 * np.ones((0, 20))
+ candidate = np.array([item for sublist in all_peaks for item in sublist])
+
+ for k in range(len(mapIdx)):
+ if k not in special_k:
+ partAs = connection_all[k][:, 0]
+ partBs = connection_all[k][:, 1]
+ indexA, indexB = np.array(limbSeq[k]) - 1
+
+ for i in range(len(connection_all[k])): # = 1:size(temp,1)
+ found = 0
+ subset_idx = [-1, -1]
+ for j in range(len(subset)): # 1:size(subset,1):
+ if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
+ subset_idx[found] = j
+ found += 1
+
+ if found == 1:
+ j = subset_idx[0]
+ if subset[j][indexB] != partBs[i]:
+ subset[j][indexB] = partBs[i]
+ subset[j][-1] += 1
+ subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
+ elif found == 2: # if found 2 and disjoint, merge them
+ j1, j2 = subset_idx
+ membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
+ if len(np.nonzero(membership == 2)[0]) == 0: # merge
+ subset[j1][:-2] += (subset[j2][:-2] + 1)
+ subset[j1][-2:] += subset[j2][-2:]
+ subset[j1][-2] += connection_all[k][i][2]
+ subset = np.delete(subset, j2, 0)
+ else: # as like found == 1
+ subset[j1][indexB] = partBs[i]
+ subset[j1][-1] += 1
+ subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
+
+ # if find no partA in the subset, create a new subset
+ elif not found and k < 17:
+ row = -1 * np.ones(20)
+ row[indexA] = partAs[i]
+ row[indexB] = partBs[i]
+ row[-1] = 2
+ row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
+ subset = np.vstack([subset, row])
+ # delete some rows of subset which has few parts occur
+ deleteIdx = []
+ for i in range(len(subset)):
+ if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
+ deleteIdx.append(i)
+ subset = np.delete(subset, deleteIdx, axis=0)
+
+ # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
+ # candidate: x, y, score, id
+ return candidate, subset
+
+ @staticmethod
+ def format_body_result(candidate: np.ndarray, subset: np.ndarray) -> List[BodyResult]:
+ """
+ Format the body results from the candidate and subset arrays into a list of BodyResult objects.
+
+ Args:
+ candidate (np.ndarray): An array of candidates containing the x, y coordinates, score, and id
+ for each body part.
+ subset (np.ndarray): An array of subsets containing indices to the candidate array for each
+ person detected. The last two columns of each row hold the total score and total parts
+ of the person.
+
+ Returns:
+ List[BodyResult]: A list of BodyResult objects, where each object represents a person with
+ detected keypoints, total score, and total parts.
+ """
+ return [
+ BodyResult(
+ keypoints=[
+ Keypoint(
+ x=candidate[candidate_index][0],
+ y=candidate[candidate_index][1],
+ score=candidate[candidate_index][2],
+ id=candidate[candidate_index][3]
+ ) if candidate_index != -1 else None
+ for candidate_index in person[:18].astype(int)
+ ],
+ total_score=person[18],
+ total_parts=person[19]
+ )
+ for person in subset
+ ]
+
+
+if __name__ == "__main__":
+ body_estimation = Body('../model/body_pose_model.pth')
+
+ test_image = '../images/ski.jpg'
+ oriImg = cv2.imread(test_image) # B,G,R order
+ candidate, subset = body_estimation(oriImg)
+ bodies = body_estimation.format_body_result(candidate, subset)
+
+ canvas = oriImg
+ for body in bodies:
+ canvas = util.draw_bodypose(canvas, body)
+
+ plt.imshow(canvas[:, :, [2, 1, 0]])
+ plt.show()
diff --git a/src/custom_controlnet_aux/open_pose/face.py b/src/custom_controlnet_aux/open_pose/face.py
new file mode 100644
index 0000000000000000000000000000000000000000..44e9ee235931bf67582cde56c6b525cb261060d2
--- /dev/null
+++ b/src/custom_controlnet_aux/open_pose/face.py
@@ -0,0 +1,365 @@
+import logging
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.nn import Conv2d, MaxPool2d, Module, ReLU, init
+from torchvision.transforms import ToPILImage, ToTensor
+
+from . import util
+
+
+class FaceNet(Module):
+ """Model the cascading heatmaps. """
+ def __init__(self):
+ super(FaceNet, self).__init__()
+ # cnn to make feature map
+ self.relu = ReLU()
+ self.max_pooling_2d = MaxPool2d(kernel_size=2, stride=2)
+ self.conv1_1 = Conv2d(in_channels=3, out_channels=64,
+ kernel_size=3, stride=1, padding=1)
+ self.conv1_2 = Conv2d(
+ in_channels=64, out_channels=64, kernel_size=3, stride=1,
+ padding=1)
+ self.conv2_1 = Conv2d(
+ in_channels=64, out_channels=128, kernel_size=3, stride=1,
+ padding=1)
+ self.conv2_2 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=3, stride=1,
+ padding=1)
+ self.conv3_1 = Conv2d(
+ in_channels=128, out_channels=256, kernel_size=3, stride=1,
+ padding=1)
+ self.conv3_2 = Conv2d(
+ in_channels=256, out_channels=256, kernel_size=3, stride=1,
+ padding=1)
+ self.conv3_3 = Conv2d(
+ in_channels=256, out_channels=256, kernel_size=3, stride=1,
+ padding=1)
+ self.conv3_4 = Conv2d(
+ in_channels=256, out_channels=256, kernel_size=3, stride=1,
+ padding=1)
+ self.conv4_1 = Conv2d(
+ in_channels=256, out_channels=512, kernel_size=3, stride=1,
+ padding=1)
+ self.conv4_2 = Conv2d(
+ in_channels=512, out_channels=512, kernel_size=3, stride=1,
+ padding=1)
+ self.conv4_3 = Conv2d(
+ in_channels=512, out_channels=512, kernel_size=3, stride=1,
+ padding=1)
+ self.conv4_4 = Conv2d(
+ in_channels=512, out_channels=512, kernel_size=3, stride=1,
+ padding=1)
+ self.conv5_1 = Conv2d(
+ in_channels=512, out_channels=512, kernel_size=3, stride=1,
+ padding=1)
+ self.conv5_2 = Conv2d(
+ in_channels=512, out_channels=512, kernel_size=3, stride=1,
+ padding=1)
+ self.conv5_3_CPM = Conv2d(
+ in_channels=512, out_channels=128, kernel_size=3, stride=1,
+ padding=1)
+
+ # stage1
+ self.conv6_1_CPM = Conv2d(
+ in_channels=128, out_channels=512, kernel_size=1, stride=1,
+ padding=0)
+ self.conv6_2_CPM = Conv2d(
+ in_channels=512, out_channels=71, kernel_size=1, stride=1,
+ padding=0)
+
+ # stage2
+ self.Mconv1_stage2 = Conv2d(
+ in_channels=199, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv2_stage2 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv3_stage2 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv4_stage2 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv5_stage2 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv6_stage2 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=1, stride=1,
+ padding=0)
+ self.Mconv7_stage2 = Conv2d(
+ in_channels=128, out_channels=71, kernel_size=1, stride=1,
+ padding=0)
+
+ # stage3
+ self.Mconv1_stage3 = Conv2d(
+ in_channels=199, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv2_stage3 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv3_stage3 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv4_stage3 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv5_stage3 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv6_stage3 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=1, stride=1,
+ padding=0)
+ self.Mconv7_stage3 = Conv2d(
+ in_channels=128, out_channels=71, kernel_size=1, stride=1,
+ padding=0)
+
+ # stage4
+ self.Mconv1_stage4 = Conv2d(
+ in_channels=199, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv2_stage4 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv3_stage4 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv4_stage4 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv5_stage4 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv6_stage4 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=1, stride=1,
+ padding=0)
+ self.Mconv7_stage4 = Conv2d(
+ in_channels=128, out_channels=71, kernel_size=1, stride=1,
+ padding=0)
+
+ # stage5
+ self.Mconv1_stage5 = Conv2d(
+ in_channels=199, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv2_stage5 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv3_stage5 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv4_stage5 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv5_stage5 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv6_stage5 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=1, stride=1,
+ padding=0)
+ self.Mconv7_stage5 = Conv2d(
+ in_channels=128, out_channels=71, kernel_size=1, stride=1,
+ padding=0)
+
+ # stage6
+ self.Mconv1_stage6 = Conv2d(
+ in_channels=199, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv2_stage6 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv3_stage6 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv4_stage6 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv5_stage6 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv6_stage6 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=1, stride=1,
+ padding=0)
+ self.Mconv7_stage6 = Conv2d(
+ in_channels=128, out_channels=71, kernel_size=1, stride=1,
+ padding=0)
+
+ for m in self.modules():
+ if isinstance(m, Conv2d):
+ init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ """Return a list of heatmaps."""
+ heatmaps = []
+
+ h = self.relu(self.conv1_1(x))
+ h = self.relu(self.conv1_2(h))
+ h = self.max_pooling_2d(h)
+ h = self.relu(self.conv2_1(h))
+ h = self.relu(self.conv2_2(h))
+ h = self.max_pooling_2d(h)
+ h = self.relu(self.conv3_1(h))
+ h = self.relu(self.conv3_2(h))
+ h = self.relu(self.conv3_3(h))
+ h = self.relu(self.conv3_4(h))
+ h = self.max_pooling_2d(h)
+ h = self.relu(self.conv4_1(h))
+ h = self.relu(self.conv4_2(h))
+ h = self.relu(self.conv4_3(h))
+ h = self.relu(self.conv4_4(h))
+ h = self.relu(self.conv5_1(h))
+ h = self.relu(self.conv5_2(h))
+ h = self.relu(self.conv5_3_CPM(h))
+ feature_map = h
+
+ # stage1
+ h = self.relu(self.conv6_1_CPM(h))
+ h = self.conv6_2_CPM(h)
+ heatmaps.append(h)
+
+ # stage2
+ h = torch.cat([h, feature_map], dim=1) # channel concat
+ h = self.relu(self.Mconv1_stage2(h))
+ h = self.relu(self.Mconv2_stage2(h))
+ h = self.relu(self.Mconv3_stage2(h))
+ h = self.relu(self.Mconv4_stage2(h))
+ h = self.relu(self.Mconv5_stage2(h))
+ h = self.relu(self.Mconv6_stage2(h))
+ h = self.Mconv7_stage2(h)
+ heatmaps.append(h)
+
+ # stage3
+ h = torch.cat([h, feature_map], dim=1) # channel concat
+ h = self.relu(self.Mconv1_stage3(h))
+ h = self.relu(self.Mconv2_stage3(h))
+ h = self.relu(self.Mconv3_stage3(h))
+ h = self.relu(self.Mconv4_stage3(h))
+ h = self.relu(self.Mconv5_stage3(h))
+ h = self.relu(self.Mconv6_stage3(h))
+ h = self.Mconv7_stage3(h)
+ heatmaps.append(h)
+
+ # stage4
+ h = torch.cat([h, feature_map], dim=1) # channel concat
+ h = self.relu(self.Mconv1_stage4(h))
+ h = self.relu(self.Mconv2_stage4(h))
+ h = self.relu(self.Mconv3_stage4(h))
+ h = self.relu(self.Mconv4_stage4(h))
+ h = self.relu(self.Mconv5_stage4(h))
+ h = self.relu(self.Mconv6_stage4(h))
+ h = self.Mconv7_stage4(h)
+ heatmaps.append(h)
+
+ # stage5
+ h = torch.cat([h, feature_map], dim=1) # channel concat
+ h = self.relu(self.Mconv1_stage5(h))
+ h = self.relu(self.Mconv2_stage5(h))
+ h = self.relu(self.Mconv3_stage5(h))
+ h = self.relu(self.Mconv4_stage5(h))
+ h = self.relu(self.Mconv5_stage5(h))
+ h = self.relu(self.Mconv6_stage5(h))
+ h = self.Mconv7_stage5(h)
+ heatmaps.append(h)
+
+ # stage6
+ h = torch.cat([h, feature_map], dim=1) # channel concat
+ h = self.relu(self.Mconv1_stage6(h))
+ h = self.relu(self.Mconv2_stage6(h))
+ h = self.relu(self.Mconv3_stage6(h))
+ h = self.relu(self.Mconv4_stage6(h))
+ h = self.relu(self.Mconv5_stage6(h))
+ h = self.relu(self.Mconv6_stage6(h))
+ h = self.Mconv7_stage6(h)
+ heatmaps.append(h)
+
+ return heatmaps
+
+
+LOG = logging.getLogger(__name__)
+TOTEN = ToTensor()
+TOPIL = ToPILImage()
+
+
+params = {
+ 'gaussian_sigma': 2.5,
+ 'inference_img_size': 736, # 368, 736, 1312
+ 'heatmap_peak_thresh': 0.1,
+ 'crop_scale': 1.5,
+ 'line_indices': [
+ [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6],
+ [6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12], [12, 13],
+ [13, 14], [14, 15], [15, 16],
+ [17, 18], [18, 19], [19, 20], [20, 21],
+ [22, 23], [23, 24], [24, 25], [25, 26],
+ [27, 28], [28, 29], [29, 30],
+ [31, 32], [32, 33], [33, 34], [34, 35],
+ [36, 37], [37, 38], [38, 39], [39, 40], [40, 41], [41, 36],
+ [42, 43], [43, 44], [44, 45], [45, 46], [46, 47], [47, 42],
+ [48, 49], [49, 50], [50, 51], [51, 52], [52, 53], [53, 54],
+ [54, 55], [55, 56], [56, 57], [57, 58], [58, 59], [59, 48],
+ [60, 61], [61, 62], [62, 63], [63, 64], [64, 65], [65, 66],
+ [66, 67], [67, 60]
+ ],
+}
+
+
+class Face(object):
+ """
+ The OpenPose face landmark detector model.
+
+ Args:
+ inference_size: set the size of the inference image size, suggested:
+ 368, 736, 1312, default 736
+ gaussian_sigma: blur the heatmaps, default 2.5
+ heatmap_peak_thresh: return landmark if over threshold, default 0.1
+
+ """
+ def __init__(self, face_model_path,
+ inference_size=None,
+ gaussian_sigma=None,
+ heatmap_peak_thresh=None):
+ self.inference_size = inference_size or params["inference_img_size"]
+ self.sigma = gaussian_sigma or params['gaussian_sigma']
+ self.threshold = heatmap_peak_thresh or params["heatmap_peak_thresh"]
+ self.model = FaceNet()
+ self.model.load_state_dict(torch.load(face_model_path))
+ self.model.eval()
+ self.device = "cpu"
+
+ def to(self, device):
+ self.model.to(device)
+ self.device = device
+ return self
+
+ def __call__(self, face_img):
+ H, W, C = face_img.shape
+
+ w_size = 384
+ x_data = torch.from_numpy(util.smart_resize(face_img, (w_size, w_size))).permute([2, 0, 1]) / 256.0 - 0.5
+
+ x_data = x_data.to(self.device)
+
+ with torch.no_grad():
+ hs = self.model(x_data[None, ...])
+ heatmaps = F.interpolate(
+ hs[-1],
+ (H, W),
+ mode='bilinear', align_corners=True).cpu().numpy()[0]
+ return heatmaps
+
+ def compute_peaks_from_heatmaps(self, heatmaps):
+ all_peaks = []
+ for part in range(heatmaps.shape[0]):
+ map_ori = heatmaps[part].copy()
+ binary = np.ascontiguousarray(map_ori > 0.05, dtype=np.uint8)
+
+ if np.sum(binary) == 0:
+ continue
+
+ positions = np.where(binary > 0.5)
+ intensities = map_ori[positions]
+ mi = np.argmax(intensities)
+ y, x = positions[0][mi], positions[1][mi]
+ all_peaks.append([x, y])
+
+ return np.array(all_peaks)
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/open_pose/hand.py b/src/custom_controlnet_aux/open_pose/hand.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba641875540e54045c599818337a9cfef961b431
--- /dev/null
+++ b/src/custom_controlnet_aux/open_pose/hand.py
@@ -0,0 +1,91 @@
+import cv2
+import numpy as np
+import torch
+from scipy.ndimage.filters import gaussian_filter
+from skimage.measure import label
+
+from . import util
+from .model import handpose_model
+
+
+class Hand(object):
+ def __init__(self, model_path):
+ self.model = handpose_model()
+ model_dict = util.transfer(self.model, torch.load(model_path))
+ self.model.load_state_dict(model_dict)
+ self.model.eval()
+ self.device = "cpu"
+
+ def to(self, device):
+ self.model.to(device)
+ self.device = device
+ return self
+
+ def __call__(self, oriImgRaw):
+ scale_search = [0.5, 1.0, 1.5, 2.0]
+ # scale_search = [0.5]
+ boxsize = 368
+ stride = 8
+ padValue = 128
+ thre = 0.05
+ multiplier = [x * boxsize for x in scale_search]
+
+ wsize = 128
+ heatmap_avg = np.zeros((wsize, wsize, 22))
+
+ Hr, Wr, Cr = oriImgRaw.shape
+
+ oriImg = cv2.GaussianBlur(oriImgRaw, (0, 0), 0.8)
+
+ for m in range(len(multiplier)):
+ scale = multiplier[m]
+ imageToTest = util.smart_resize(oriImg, (scale, scale))
+
+ imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
+ im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
+ im = np.ascontiguousarray(im)
+
+ data = torch.from_numpy(im).float()
+ data = data.to(self.device)
+
+ with torch.no_grad():
+ output = self.model(data).cpu().numpy()
+
+ # extract outputs, resize, and remove padding
+ heatmap = np.transpose(np.squeeze(output), (1, 2, 0)) # output 1 is heatmaps
+ heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride)
+ heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+ heatmap = util.smart_resize(heatmap, (wsize, wsize))
+
+ heatmap_avg += heatmap / len(multiplier)
+
+ all_peaks = []
+ for part in range(21):
+ map_ori = heatmap_avg[:, :, part]
+ one_heatmap = gaussian_filter(map_ori, sigma=3)
+ binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)
+
+ if np.sum(binary) == 0:
+ all_peaks.append([0, 0])
+ continue
+ label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim)
+ max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1
+ label_img[label_img != max_index] = 0
+ map_ori[label_img == 0] = 0
+
+ y, x = util.npmax(map_ori)
+ y = int(float(y) * float(Hr) / float(wsize))
+ x = int(float(x) * float(Wr) / float(wsize))
+ all_peaks.append([x, y])
+ return np.array(all_peaks)
+
+if __name__ == "__main__":
+ hand_estimation = Hand('../model/hand_pose_model.pth')
+
+ # test_image = '../images/hand.jpg'
+ test_image = '../images/hand.jpg'
+ oriImg = cv2.imread(test_image) # B,G,R order
+ peaks = hand_estimation(oriImg)
+ canvas = util.draw_handpose(oriImg, peaks, True)
+ cv2.imshow('', canvas)
+ cv2.waitKey(0)
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/open_pose/model.py b/src/custom_controlnet_aux/open_pose/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c3d47268986f8018b2c75307a7725d364b175fe
--- /dev/null
+++ b/src/custom_controlnet_aux/open_pose/model.py
@@ -0,0 +1,217 @@
+import torch
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+
+def make_layers(block, no_relu_layers):
+ layers = []
+ for layer_name, v in block.items():
+ if 'pool' in layer_name:
+ layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1],
+ padding=v[2])
+ layers.append((layer_name, layer))
+ else:
+ conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1],
+ kernel_size=v[2], stride=v[3],
+ padding=v[4])
+ layers.append((layer_name, conv2d))
+ if layer_name not in no_relu_layers:
+ layers.append(('relu_'+layer_name, nn.ReLU(inplace=True)))
+
+ return nn.Sequential(OrderedDict(layers))
+
+class bodypose_model(nn.Module):
+ def __init__(self):
+ super(bodypose_model, self).__init__()
+
+ # these layers have no relu layer
+ no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\
+ 'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\
+ 'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\
+ 'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L1']
+ blocks = {}
+ block0 = OrderedDict([
+ ('conv1_1', [3, 64, 3, 1, 1]),
+ ('conv1_2', [64, 64, 3, 1, 1]),
+ ('pool1_stage1', [2, 2, 0]),
+ ('conv2_1', [64, 128, 3, 1, 1]),
+ ('conv2_2', [128, 128, 3, 1, 1]),
+ ('pool2_stage1', [2, 2, 0]),
+ ('conv3_1', [128, 256, 3, 1, 1]),
+ ('conv3_2', [256, 256, 3, 1, 1]),
+ ('conv3_3', [256, 256, 3, 1, 1]),
+ ('conv3_4', [256, 256, 3, 1, 1]),
+ ('pool3_stage1', [2, 2, 0]),
+ ('conv4_1', [256, 512, 3, 1, 1]),
+ ('conv4_2', [512, 512, 3, 1, 1]),
+ ('conv4_3_CPM', [512, 256, 3, 1, 1]),
+ ('conv4_4_CPM', [256, 128, 3, 1, 1])
+ ])
+
+
+ # Stage 1
+ block1_1 = OrderedDict([
+ ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]),
+ ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
+ ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]),
+ ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
+ ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])
+ ])
+
+ block1_2 = OrderedDict([
+ ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]),
+ ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
+ ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]),
+ ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
+ ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])
+ ])
+ blocks['block1_1'] = block1_1
+ blocks['block1_2'] = block1_2
+
+ self.model0 = make_layers(block0, no_relu_layers)
+
+ # Stages 2 - 6
+ for i in range(2, 7):
+ blocks['block%d_1' % i] = OrderedDict([
+ ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]),
+ ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+ ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+ ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+ ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+ ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]),
+ ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])
+ ])
+
+ blocks['block%d_2' % i] = OrderedDict([
+ ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]),
+ ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+ ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+ ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+ ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+ ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]),
+ ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])
+ ])
+
+ for k in blocks.keys():
+ blocks[k] = make_layers(blocks[k], no_relu_layers)
+
+ self.model1_1 = blocks['block1_1']
+ self.model2_1 = blocks['block2_1']
+ self.model3_1 = blocks['block3_1']
+ self.model4_1 = blocks['block4_1']
+ self.model5_1 = blocks['block5_1']
+ self.model6_1 = blocks['block6_1']
+
+ self.model1_2 = blocks['block1_2']
+ self.model2_2 = blocks['block2_2']
+ self.model3_2 = blocks['block3_2']
+ self.model4_2 = blocks['block4_2']
+ self.model5_2 = blocks['block5_2']
+ self.model6_2 = blocks['block6_2']
+
+
+ def forward(self, x):
+
+ out1 = self.model0(x)
+
+ out1_1 = self.model1_1(out1)
+ out1_2 = self.model1_2(out1)
+ out2 = torch.cat([out1_1, out1_2, out1], 1)
+
+ out2_1 = self.model2_1(out2)
+ out2_2 = self.model2_2(out2)
+ out3 = torch.cat([out2_1, out2_2, out1], 1)
+
+ out3_1 = self.model3_1(out3)
+ out3_2 = self.model3_2(out3)
+ out4 = torch.cat([out3_1, out3_2, out1], 1)
+
+ out4_1 = self.model4_1(out4)
+ out4_2 = self.model4_2(out4)
+ out5 = torch.cat([out4_1, out4_2, out1], 1)
+
+ out5_1 = self.model5_1(out5)
+ out5_2 = self.model5_2(out5)
+ out6 = torch.cat([out5_1, out5_2, out1], 1)
+
+ out6_1 = self.model6_1(out6)
+ out6_2 = self.model6_2(out6)
+
+ return out6_1, out6_2
+
+class handpose_model(nn.Module):
+ def __init__(self):
+ super(handpose_model, self).__init__()
+
+ # these layers have no relu layer
+ no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\
+ 'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6']
+ # stage 1
+ block1_0 = OrderedDict([
+ ('conv1_1', [3, 64, 3, 1, 1]),
+ ('conv1_2', [64, 64, 3, 1, 1]),
+ ('pool1_stage1', [2, 2, 0]),
+ ('conv2_1', [64, 128, 3, 1, 1]),
+ ('conv2_2', [128, 128, 3, 1, 1]),
+ ('pool2_stage1', [2, 2, 0]),
+ ('conv3_1', [128, 256, 3, 1, 1]),
+ ('conv3_2', [256, 256, 3, 1, 1]),
+ ('conv3_3', [256, 256, 3, 1, 1]),
+ ('conv3_4', [256, 256, 3, 1, 1]),
+ ('pool3_stage1', [2, 2, 0]),
+ ('conv4_1', [256, 512, 3, 1, 1]),
+ ('conv4_2', [512, 512, 3, 1, 1]),
+ ('conv4_3', [512, 512, 3, 1, 1]),
+ ('conv4_4', [512, 512, 3, 1, 1]),
+ ('conv5_1', [512, 512, 3, 1, 1]),
+ ('conv5_2', [512, 512, 3, 1, 1]),
+ ('conv5_3_CPM', [512, 128, 3, 1, 1])
+ ])
+
+ block1_1 = OrderedDict([
+ ('conv6_1_CPM', [128, 512, 1, 1, 0]),
+ ('conv6_2_CPM', [512, 22, 1, 1, 0])
+ ])
+
+ blocks = {}
+ blocks['block1_0'] = block1_0
+ blocks['block1_1'] = block1_1
+
+ # stage 2-6
+ for i in range(2, 7):
+ blocks['block%d' % i] = OrderedDict([
+ ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]),
+ ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]),
+ ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]),
+ ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]),
+ ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]),
+ ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]),
+ ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0])
+ ])
+
+ for k in blocks.keys():
+ blocks[k] = make_layers(blocks[k], no_relu_layers)
+
+ self.model1_0 = blocks['block1_0']
+ self.model1_1 = blocks['block1_1']
+ self.model2 = blocks['block2']
+ self.model3 = blocks['block3']
+ self.model4 = blocks['block4']
+ self.model5 = blocks['block5']
+ self.model6 = blocks['block6']
+
+ def forward(self, x):
+ out1_0 = self.model1_0(x)
+ out1_1 = self.model1_1(out1_0)
+ concat_stage2 = torch.cat([out1_1, out1_0], 1)
+ out_stage2 = self.model2(concat_stage2)
+ concat_stage3 = torch.cat([out_stage2, out1_0], 1)
+ out_stage3 = self.model3(concat_stage3)
+ concat_stage4 = torch.cat([out_stage3, out1_0], 1)
+ out_stage4 = self.model4(concat_stage4)
+ concat_stage5 = torch.cat([out_stage4, out1_0], 1)
+ out_stage5 = self.model5(concat_stage5)
+ concat_stage6 = torch.cat([out_stage5, out1_0], 1)
+ out_stage6 = self.model6(concat_stage6)
+ return out_stage6
diff --git a/src/custom_controlnet_aux/open_pose/util.py b/src/custom_controlnet_aux/open_pose/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..21a14c9fbc7df9f47b047df6cc82b5e5ae0fc4e1
--- /dev/null
+++ b/src/custom_controlnet_aux/open_pose/util.py
@@ -0,0 +1,390 @@
+import math
+import numpy as np
+import matplotlib
+import cv2
+from typing import List, Tuple, Union
+
+from .body import BodyResult, Keypoint
+
+eps = 0.01
+
+
+def smart_resize(x, s):
+ Ht, Wt = s
+ if x.ndim == 2:
+ Ho, Wo = x.shape
+ Co = 1
+ else:
+ Ho, Wo, Co = x.shape
+ if Co == 3 or Co == 1:
+ k = float(Ht + Wt) / float(Ho + Wo)
+ return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
+ else:
+ return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
+
+
+def smart_resize_k(x, fx, fy):
+ if x.ndim == 2:
+ Ho, Wo = x.shape
+ Co = 1
+ else:
+ Ho, Wo, Co = x.shape
+ Ht, Wt = Ho * fy, Wo * fx
+ if Co == 3 or Co == 1:
+ k = float(Ht + Wt) / float(Ho + Wo)
+ return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
+ else:
+ return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
+
+
+def padRightDownCorner(img, stride, padValue):
+ h = img.shape[0]
+ w = img.shape[1]
+
+ pad = 4 * [None]
+ pad[0] = 0 # up
+ pad[1] = 0 # left
+ pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
+ pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
+
+ img_padded = img
+ pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1))
+ img_padded = np.concatenate((pad_up, img_padded), axis=0)
+ pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1))
+ img_padded = np.concatenate((pad_left, img_padded), axis=1)
+ pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1))
+ img_padded = np.concatenate((img_padded, pad_down), axis=0)
+ pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1))
+ img_padded = np.concatenate((img_padded, pad_right), axis=1)
+
+ return img_padded, pad
+
+
+def transfer(model, model_weights):
+ transfered_model_weights = {}
+ for weights_name in model.state_dict().keys():
+ transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
+ return transfered_model_weights
+
+
+def draw_bodypose(canvas: np.ndarray, keypoints: List[Keypoint], xinsr_stick_scaling: bool = False) -> np.ndarray:
+ """
+ Draw keypoints and limbs representing body pose on a given canvas.
+
+ Args:
+ canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the body pose.
+ keypoints (List[Keypoint]): A list of Keypoint objects representing the body keypoints to be drawn.
+ xinsr_stick_scaling (bool): Whether or not scaling stick width for xinsr ControlNet
+
+ Returns:
+ np.ndarray: A 3D numpy array representing the modified canvas with the drawn body pose.
+
+ Note:
+ The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
+ """
+ H, W, C = canvas.shape
+ stickwidth = 4
+ # Ref: https://huggingface.co/xinsir/controlnet-openpose-sdxl-1.0
+ max_side = max(H, W)
+ if xinsr_stick_scaling:
+ stick_scale = 1 if max_side < 500 else min(2 + (max_side // 1000), 7)
+ else:
+ stick_scale = 1
+
+ limbSeq = [
+ [2, 3], [2, 6], [3, 4], [4, 5],
+ [6, 7], [7, 8], [2, 9], [9, 10],
+ [10, 11], [2, 12], [12, 13], [13, 14],
+ [2, 1], [1, 15], [15, 17], [1, 16],
+ [16, 18],
+ ]
+
+ colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
+
+ for (k1_index, k2_index), color in zip(limbSeq, colors):
+ keypoint1 = keypoints[k1_index - 1]
+ keypoint2 = keypoints[k2_index - 1]
+
+ if keypoint1 is None or keypoint2 is None:
+ continue
+
+ Y = np.array([keypoint1.x, keypoint2.x]) * float(W)
+ X = np.array([keypoint1.y, keypoint2.y]) * float(H)
+ mX = np.mean(X)
+ mY = np.mean(Y)
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth*stick_scale), int(angle), 0, 360, 1)
+ cv2.fillConvexPoly(canvas, polygon, [int(float(c) * 0.6) for c in color])
+
+ for keypoint, color in zip(keypoints, colors):
+ if keypoint is None:
+ continue
+
+ x, y = keypoint.x, keypoint.y
+ x = int(x * W)
+ y = int(y * H)
+ cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1)
+
+ return canvas
+
+
+def draw_handpose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray:
+ """
+ Draw keypoints and connections representing hand pose on a given canvas.
+
+ Args:
+ canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.
+ keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn
+ or None if no keypoints are present.
+
+ Returns:
+ np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.
+
+ Note:
+ The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
+ """
+ if not keypoints:
+ return canvas
+
+ H, W, C = canvas.shape
+
+ edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
+ [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
+
+ for ie, (e1, e2) in enumerate(edges):
+ k1 = keypoints[e1]
+ k2 = keypoints[e2]
+ if k1 is None or k2 is None:
+ continue
+
+ x1 = int(k1.x * W)
+ y1 = int(k1.y * H)
+ x2 = int(k2.x * W)
+ y2 = int(k2.y * H)
+ if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
+ cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2)
+
+ for keypoint in keypoints:
+ x, y = keypoint.x, keypoint.y
+ x = int(x * W)
+ y = int(y * H)
+ if x > eps and y > eps:
+ cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
+ return canvas
+
+
+def draw_facepose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray:
+ """
+ Draw keypoints representing face pose on a given canvas.
+
+ Args:
+ canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the face pose.
+ keypoints (List[Keypoint]| None): A list of Keypoint objects representing the face keypoints to be drawn
+ or None if no keypoints are present.
+
+ Returns:
+ np.ndarray: A 3D numpy array representing the modified canvas with the drawn face pose.
+
+ Note:
+ The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
+ """
+ if not keypoints:
+ return canvas
+
+ H, W, C = canvas.shape
+ for keypoint in keypoints:
+ x, y = keypoint.x, keypoint.y
+ x = int(x * W)
+ y = int(y * H)
+ if x > eps and y > eps:
+ cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
+ return canvas
+
+
+# detect hand according to body pose keypoints
+# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
+def handDetect(body: BodyResult, oriImg) -> List[Tuple[int, int, int, bool]]:
+ """
+ Detect hands in the input body pose keypoints and calculate the bounding box for each hand.
+
+ Args:
+ body (BodyResult): A BodyResult object containing the detected body pose keypoints.
+ oriImg (numpy.ndarray): A 3D numpy array representing the original input image.
+
+ Returns:
+ List[Tuple[int, int, int, bool]]: A list of tuples, each containing the coordinates (x, y) of the top-left
+ corner of the bounding box, the width (height) of the bounding box, and
+ a boolean flag indicating whether the hand is a left hand (True) or a
+ right hand (False).
+
+ Notes:
+ - The width and height of the bounding boxes are equal since the network requires squared input.
+ - The minimum bounding box size is 20 pixels.
+ """
+ ratioWristElbow = 0.33
+ detect_result = []
+ image_height, image_width = oriImg.shape[0:2]
+
+ keypoints = body.keypoints
+ # right hand: wrist 4, elbow 3, shoulder 2
+ # left hand: wrist 7, elbow 6, shoulder 5
+ left_shoulder = keypoints[5]
+ left_elbow = keypoints[6]
+ left_wrist = keypoints[7]
+ right_shoulder = keypoints[2]
+ right_elbow = keypoints[3]
+ right_wrist = keypoints[4]
+
+ # if any of three not detected
+ has_left = all(keypoint is not None for keypoint in (left_shoulder, left_elbow, left_wrist))
+ has_right = all(keypoint is not None for keypoint in (right_shoulder, right_elbow, right_wrist))
+ if not (has_left or has_right):
+ return []
+
+ hands = []
+ #left hand
+ if has_left:
+ hands.append([
+ left_shoulder.x, left_shoulder.y,
+ left_elbow.x, left_elbow.y,
+ left_wrist.x, left_wrist.y,
+ True
+ ])
+ # right hand
+ if has_right:
+ hands.append([
+ right_shoulder.x, right_shoulder.y,
+ right_elbow.x, right_elbow.y,
+ right_wrist.x, right_wrist.y,
+ False
+ ])
+
+ for x1, y1, x2, y2, x3, y3, is_left in hands:
+ # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
+ # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
+ # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
+ # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
+ # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
+ # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
+ x = x3 + ratioWristElbow * (x3 - x2)
+ y = y3 + ratioWristElbow * (y3 - y2)
+ distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
+ distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
+ width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
+ # x-y refers to the center --> offset to topLeft point
+ # handRectangle.x -= handRectangle.width / 2.f;
+ # handRectangle.y -= handRectangle.height / 2.f;
+ x -= width / 2
+ y -= width / 2 # width = height
+ # overflow the image
+ if x < 0: x = 0
+ if y < 0: y = 0
+ width1 = width
+ width2 = width
+ if x + width > image_width: width1 = image_width - x
+ if y + width > image_height: width2 = image_height - y
+ width = min(width1, width2)
+ # the max hand box value is 20 pixels
+ if width >= 20:
+ detect_result.append((int(x), int(y), int(width), is_left))
+
+ '''
+ return value: [[x, y, w, True if left hand else False]].
+ width=height since the network require squared input.
+ x, y is the coordinate of top left
+ '''
+ return detect_result
+
+
+# Written by Lvmin
+def faceDetect(body: BodyResult, oriImg) -> Union[Tuple[int, int, int], None]:
+ """
+ Detect the face in the input body pose keypoints and calculate the bounding box for the face.
+
+ Args:
+ body (BodyResult): A BodyResult object containing the detected body pose keypoints.
+ oriImg (numpy.ndarray): A 3D numpy array representing the original input image.
+
+ Returns:
+ Tuple[int, int, int] | None: A tuple containing the coordinates (x, y) of the top-left corner of the
+ bounding box and the width (height) of the bounding box, or None if the
+ face is not detected or the bounding box width is less than 20 pixels.
+
+ Notes:
+ - The width and height of the bounding box are equal.
+ - The minimum bounding box size is 20 pixels.
+ """
+ # left right eye ear 14 15 16 17
+ image_height, image_width = oriImg.shape[0:2]
+
+ keypoints = body.keypoints
+ head = keypoints[0]
+ left_eye = keypoints[14]
+ right_eye = keypoints[15]
+ left_ear = keypoints[16]
+ right_ear = keypoints[17]
+
+ if head is None or all(keypoint is None for keypoint in (left_eye, right_eye, left_ear, right_ear)):
+ return None
+
+ width = 0.0
+ x0, y0 = head.x, head.y
+
+ if left_eye is not None:
+ x1, y1 = left_eye.x, left_eye.y
+ d = max(abs(x0 - x1), abs(y0 - y1))
+ width = max(width, d * 3.0)
+
+ if right_eye is not None:
+ x1, y1 = right_eye.x, right_eye.y
+ d = max(abs(x0 - x1), abs(y0 - y1))
+ width = max(width, d * 3.0)
+
+ if left_ear is not None:
+ x1, y1 = left_ear.x, left_ear.y
+ d = max(abs(x0 - x1), abs(y0 - y1))
+ width = max(width, d * 1.5)
+
+ if right_ear is not None:
+ x1, y1 = right_ear.x, right_ear.y
+ d = max(abs(x0 - x1), abs(y0 - y1))
+ width = max(width, d * 1.5)
+
+ x, y = x0, y0
+
+ x -= width
+ y -= width
+
+ if x < 0:
+ x = 0
+
+ if y < 0:
+ y = 0
+
+ width1 = width * 2
+ width2 = width * 2
+
+ if x + width > image_width:
+ width1 = image_width - x
+
+ if y + width > image_height:
+ width2 = image_height - y
+
+ width = min(width1, width2)
+
+ if width >= 20:
+ return int(x), int(y), int(width)
+ else:
+ return None
+
+
+# get max index of 2d array
+def npmax(array):
+ arrayindex = array.argmax(1)
+ arrayvalue = array.max(1)
+ i = arrayvalue.argmax()
+ j = arrayindex[i]
+ return i, j
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/pidi/LICENSE b/src/custom_controlnet_aux/pidi/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..913b6cf92c19d37b6ee4f7bc99c65f655e7f840c
--- /dev/null
+++ b/src/custom_controlnet_aux/pidi/LICENSE
@@ -0,0 +1,21 @@
+It is just for research purpose, and commercial use should be contacted with authors first.
+
+Copyright (c) 2021 Zhuo Su
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/pidi/__init__.py b/src/custom_controlnet_aux/pidi/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..81fb25ac6bd2a2341c0d7163ab4c08f971820193
--- /dev/null
+++ b/src/custom_controlnet_aux/pidi/__init__.py
@@ -0,0 +1,64 @@
+import os
+import warnings
+
+import cv2
+import numpy as np
+import torch
+from einops import rearrange
+from PIL import Image
+
+from custom_controlnet_aux.util import HWC3, nms, resize_image_with_pad, safe_step,common_input_validate, custom_hf_download, HF_MODEL_NAME
+from .model import pidinet
+
+
+class PidiNetDetector:
+ def __init__(self, netNetwork):
+ self.netNetwork = netNetwork
+ self.device = "cpu"
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, filename="table5_pidinet.pth"):
+ model_path = custom_hf_download(pretrained_model_or_path, filename)
+
+ netNetwork = pidinet()
+ netNetwork.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_path)['state_dict'].items()})
+ netNetwork.eval()
+
+ return cls(netNetwork)
+
+ def to(self, device):
+ self.netNetwork.to(device)
+ self.device = device
+ return self
+
+ def __call__(self, input_image, detect_resolution=512, safe=False, output_type="pil", scribble=False, apply_filter=False, upscale_method="INTER_CUBIC", **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+
+ detected_map = detected_map[:, :, ::-1].copy()
+ with torch.no_grad():
+ image_pidi = torch.from_numpy(detected_map).float().to(self.device)
+ image_pidi = image_pidi / 255.0
+ image_pidi = rearrange(image_pidi, 'h w c -> 1 c h w')
+ edge = self.netNetwork(image_pidi)[-1]
+ edge = edge.cpu().numpy()
+ if apply_filter:
+ edge = edge > 0.5
+ if safe:
+ edge = safe_step(edge)
+ edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
+
+ detected_map = edge[0, 0]
+
+ if scribble:
+ detected_map = nms(detected_map, 127, 3.0)
+ detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
+ detected_map[detected_map > 4] = 255
+ detected_map[detected_map < 255] = 0
+
+ detected_map = HWC3(remove_pad(detected_map))
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
diff --git a/src/custom_controlnet_aux/pidi/model.py b/src/custom_controlnet_aux/pidi/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..16595b35a4f75a6d2b0e832e24b6e11706d77326
--- /dev/null
+++ b/src/custom_controlnet_aux/pidi/model.py
@@ -0,0 +1,681 @@
+"""
+Author: Zhuo Su, Wenzhe Liu
+Date: Feb 18, 2021
+"""
+
+import math
+
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def img2tensor(imgs, bgr2rgb=True, float32=True):
+ """Numpy array to tensor.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Input images.
+ bgr2rgb (bool): Whether to change bgr to rgb.
+ float32 (bool): Whether to change to float32.
+
+ Returns:
+ list[tensor] | tensor: Tensor images. If returned results only have
+ one element, just return tensor.
+ """
+
+ def _totensor(img, bgr2rgb, float32):
+ if img.shape[2] == 3 and bgr2rgb:
+ if img.dtype == 'float64':
+ img = img.astype('float32')
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = torch.from_numpy(img.transpose(2, 0, 1))
+ if float32:
+ img = img.float()
+ return img
+
+ if isinstance(imgs, list):
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
+ else:
+ return _totensor(imgs, bgr2rgb, float32)
+
+nets = {
+ 'baseline': {
+ 'layer0': 'cv',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'cv',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'cv',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'cv',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'c-v15': {
+ 'layer0': 'cd',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'cv',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'cv',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'cv',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'a-v15': {
+ 'layer0': 'ad',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'cv',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'cv',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'cv',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'r-v15': {
+ 'layer0': 'rd',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'cv',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'cv',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'cv',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'cvvv4': {
+ 'layer0': 'cd',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'cd',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'cd',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'cd',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'avvv4': {
+ 'layer0': 'ad',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'ad',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'ad',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'ad',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'rvvv4': {
+ 'layer0': 'rd',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'rd',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'rd',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'rd',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'cccv4': {
+ 'layer0': 'cd',
+ 'layer1': 'cd',
+ 'layer2': 'cd',
+ 'layer3': 'cv',
+ 'layer4': 'cd',
+ 'layer5': 'cd',
+ 'layer6': 'cd',
+ 'layer7': 'cv',
+ 'layer8': 'cd',
+ 'layer9': 'cd',
+ 'layer10': 'cd',
+ 'layer11': 'cv',
+ 'layer12': 'cd',
+ 'layer13': 'cd',
+ 'layer14': 'cd',
+ 'layer15': 'cv',
+ },
+ 'aaav4': {
+ 'layer0': 'ad',
+ 'layer1': 'ad',
+ 'layer2': 'ad',
+ 'layer3': 'cv',
+ 'layer4': 'ad',
+ 'layer5': 'ad',
+ 'layer6': 'ad',
+ 'layer7': 'cv',
+ 'layer8': 'ad',
+ 'layer9': 'ad',
+ 'layer10': 'ad',
+ 'layer11': 'cv',
+ 'layer12': 'ad',
+ 'layer13': 'ad',
+ 'layer14': 'ad',
+ 'layer15': 'cv',
+ },
+ 'rrrv4': {
+ 'layer0': 'rd',
+ 'layer1': 'rd',
+ 'layer2': 'rd',
+ 'layer3': 'cv',
+ 'layer4': 'rd',
+ 'layer5': 'rd',
+ 'layer6': 'rd',
+ 'layer7': 'cv',
+ 'layer8': 'rd',
+ 'layer9': 'rd',
+ 'layer10': 'rd',
+ 'layer11': 'cv',
+ 'layer12': 'rd',
+ 'layer13': 'rd',
+ 'layer14': 'rd',
+ 'layer15': 'cv',
+ },
+ 'c16': {
+ 'layer0': 'cd',
+ 'layer1': 'cd',
+ 'layer2': 'cd',
+ 'layer3': 'cd',
+ 'layer4': 'cd',
+ 'layer5': 'cd',
+ 'layer6': 'cd',
+ 'layer7': 'cd',
+ 'layer8': 'cd',
+ 'layer9': 'cd',
+ 'layer10': 'cd',
+ 'layer11': 'cd',
+ 'layer12': 'cd',
+ 'layer13': 'cd',
+ 'layer14': 'cd',
+ 'layer15': 'cd',
+ },
+ 'a16': {
+ 'layer0': 'ad',
+ 'layer1': 'ad',
+ 'layer2': 'ad',
+ 'layer3': 'ad',
+ 'layer4': 'ad',
+ 'layer5': 'ad',
+ 'layer6': 'ad',
+ 'layer7': 'ad',
+ 'layer8': 'ad',
+ 'layer9': 'ad',
+ 'layer10': 'ad',
+ 'layer11': 'ad',
+ 'layer12': 'ad',
+ 'layer13': 'ad',
+ 'layer14': 'ad',
+ 'layer15': 'ad',
+ },
+ 'r16': {
+ 'layer0': 'rd',
+ 'layer1': 'rd',
+ 'layer2': 'rd',
+ 'layer3': 'rd',
+ 'layer4': 'rd',
+ 'layer5': 'rd',
+ 'layer6': 'rd',
+ 'layer7': 'rd',
+ 'layer8': 'rd',
+ 'layer9': 'rd',
+ 'layer10': 'rd',
+ 'layer11': 'rd',
+ 'layer12': 'rd',
+ 'layer13': 'rd',
+ 'layer14': 'rd',
+ 'layer15': 'rd',
+ },
+ 'carv4': {
+ 'layer0': 'cd',
+ 'layer1': 'ad',
+ 'layer2': 'rd',
+ 'layer3': 'cv',
+ 'layer4': 'cd',
+ 'layer5': 'ad',
+ 'layer6': 'rd',
+ 'layer7': 'cv',
+ 'layer8': 'cd',
+ 'layer9': 'ad',
+ 'layer10': 'rd',
+ 'layer11': 'cv',
+ 'layer12': 'cd',
+ 'layer13': 'ad',
+ 'layer14': 'rd',
+ 'layer15': 'cv',
+ },
+ }
+
+def createConvFunc(op_type):
+ assert op_type in ['cv', 'cd', 'ad', 'rd'], 'unknown op type: %s' % str(op_type)
+ if op_type == 'cv':
+ return F.conv2d
+
+ if op_type == 'cd':
+ def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ assert dilation in [1, 2], 'dilation for cd_conv should be in 1 or 2'
+ assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for cd_conv should be 3x3'
+ assert padding == dilation, 'padding for cd_conv set wrong'
+
+ weights_c = weights.sum(dim=[2, 3], keepdim=True)
+ yc = F.conv2d(x, weights_c, stride=stride, padding=0, groups=groups)
+ y = F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+ return y - yc
+ return func
+ elif op_type == 'ad':
+ def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ assert dilation in [1, 2], 'dilation for ad_conv should be in 1 or 2'
+ assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for ad_conv should be 3x3'
+ assert padding == dilation, 'padding for ad_conv set wrong'
+
+ shape = weights.shape
+ weights = weights.view(shape[0], shape[1], -1)
+ weights_conv = (weights - weights[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape) # clock-wise
+ y = F.conv2d(x, weights_conv, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+ return y
+ return func
+ elif op_type == 'rd':
+ def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ assert dilation in [1, 2], 'dilation for rd_conv should be in 1 or 2'
+ assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for rd_conv should be 3x3'
+ padding = 2 * dilation
+
+ shape = weights.shape
+ if weights.is_cuda:
+ buffer = torch.cuda.FloatTensor(shape[0], shape[1], 5 * 5).fill_(0)
+ else:
+ buffer = torch.zeros(shape[0], shape[1], 5 * 5).to(weights.device)
+ weights = weights.view(shape[0], shape[1], -1)
+ buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weights[:, :, 1:]
+ buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weights[:, :, 1:]
+ buffer[:, :, 12] = 0
+ buffer = buffer.view(shape[0], shape[1], 5, 5)
+ y = F.conv2d(x, buffer, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+ return y
+ return func
+ else:
+ print('impossible to be here unless you force that')
+ return None
+
+class Conv2d(nn.Module):
+ def __init__(self, pdc, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False):
+ super(Conv2d, self).__init__()
+ if in_channels % groups != 0:
+ raise ValueError('in_channels must be divisible by groups')
+ if out_channels % groups != 0:
+ raise ValueError('out_channels must be divisible by groups')
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.groups = groups
+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size))
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
+ else:
+ self.register_parameter('bias', None)
+ self.reset_parameters()
+ self.pdc = pdc
+
+ def reset_parameters(self):
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+ if self.bias is not None:
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
+ bound = 1 / math.sqrt(fan_in)
+ nn.init.uniform_(self.bias, -bound, bound)
+
+ def forward(self, input):
+
+ return self.pdc(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+class CSAM(nn.Module):
+ """
+ Compact Spatial Attention Module
+ """
+ def __init__(self, channels):
+ super(CSAM, self).__init__()
+
+ mid_channels = 4
+ self.relu1 = nn.ReLU()
+ self.conv1 = nn.Conv2d(channels, mid_channels, kernel_size=1, padding=0)
+ self.conv2 = nn.Conv2d(mid_channels, 1, kernel_size=3, padding=1, bias=False)
+ self.sigmoid = nn.Sigmoid()
+ nn.init.constant_(self.conv1.bias, 0)
+
+ def forward(self, x):
+ y = self.relu1(x)
+ y = self.conv1(y)
+ y = self.conv2(y)
+ y = self.sigmoid(y)
+
+ return x * y
+
+class CDCM(nn.Module):
+ """
+ Compact Dilation Convolution based Module
+ """
+ def __init__(self, in_channels, out_channels):
+ super(CDCM, self).__init__()
+
+ self.relu1 = nn.ReLU()
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
+ self.conv2_1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=5, padding=5, bias=False)
+ self.conv2_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=7, padding=7, bias=False)
+ self.conv2_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=9, padding=9, bias=False)
+ self.conv2_4 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=11, padding=11, bias=False)
+ nn.init.constant_(self.conv1.bias, 0)
+
+ def forward(self, x):
+ x = self.relu1(x)
+ x = self.conv1(x)
+ x1 = self.conv2_1(x)
+ x2 = self.conv2_2(x)
+ x3 = self.conv2_3(x)
+ x4 = self.conv2_4(x)
+ return x1 + x2 + x3 + x4
+
+
+class MapReduce(nn.Module):
+ """
+ Reduce feature maps into a single edge map
+ """
+ def __init__(self, channels):
+ super(MapReduce, self).__init__()
+ self.conv = nn.Conv2d(channels, 1, kernel_size=1, padding=0)
+ nn.init.constant_(self.conv.bias, 0)
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class PDCBlock(nn.Module):
+ def __init__(self, pdc, inplane, ouplane, stride=1):
+ super(PDCBlock, self).__init__()
+ self.stride=stride
+
+ self.stride=stride
+ if self.stride > 1:
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0)
+ self.conv1 = Conv2d(pdc, inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False)
+ self.relu2 = nn.ReLU()
+ self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False)
+
+ def forward(self, x):
+ if self.stride > 1:
+ x = self.pool(x)
+ y = self.conv1(x)
+ y = self.relu2(y)
+ y = self.conv2(y)
+ if self.stride > 1:
+ x = self.shortcut(x)
+ y = y + x
+ return y
+
+class PDCBlock_converted(nn.Module):
+ """
+ CPDC, APDC can be converted to vanilla 3x3 convolution
+ RPDC can be converted to vanilla 5x5 convolution
+ """
+ def __init__(self, pdc, inplane, ouplane, stride=1):
+ super(PDCBlock_converted, self).__init__()
+ self.stride=stride
+
+ if self.stride > 1:
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0)
+ if pdc == 'rd':
+ self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=5, padding=2, groups=inplane, bias=False)
+ else:
+ self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False)
+ self.relu2 = nn.ReLU()
+ self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False)
+
+ def forward(self, x):
+ if self.stride > 1:
+ x = self.pool(x)
+ y = self.conv1(x)
+ y = self.relu2(y)
+ y = self.conv2(y)
+ if self.stride > 1:
+ x = self.shortcut(x)
+ y = y + x
+ return y
+
+class PiDiNet(nn.Module):
+ def __init__(self, inplane, pdcs, dil=None, sa=False, convert=False):
+ super(PiDiNet, self).__init__()
+ self.sa = sa
+ if dil is not None:
+ assert isinstance(dil, int), 'dil should be an int'
+ self.dil = dil
+
+ self.fuseplanes = []
+
+ self.inplane = inplane
+ if convert:
+ if pdcs[0] == 'rd':
+ init_kernel_size = 5
+ init_padding = 2
+ else:
+ init_kernel_size = 3
+ init_padding = 1
+ self.init_block = nn.Conv2d(3, self.inplane,
+ kernel_size=init_kernel_size, padding=init_padding, bias=False)
+ block_class = PDCBlock_converted
+ else:
+ self.init_block = Conv2d(pdcs[0], 3, self.inplane, kernel_size=3, padding=1)
+ block_class = PDCBlock
+
+ self.block1_1 = block_class(pdcs[1], self.inplane, self.inplane)
+ self.block1_2 = block_class(pdcs[2], self.inplane, self.inplane)
+ self.block1_3 = block_class(pdcs[3], self.inplane, self.inplane)
+ self.fuseplanes.append(self.inplane) # C
+
+ inplane = self.inplane
+ self.inplane = self.inplane * 2
+ self.block2_1 = block_class(pdcs[4], inplane, self.inplane, stride=2)
+ self.block2_2 = block_class(pdcs[5], self.inplane, self.inplane)
+ self.block2_3 = block_class(pdcs[6], self.inplane, self.inplane)
+ self.block2_4 = block_class(pdcs[7], self.inplane, self.inplane)
+ self.fuseplanes.append(self.inplane) # 2C
+
+ inplane = self.inplane
+ self.inplane = self.inplane * 2
+ self.block3_1 = block_class(pdcs[8], inplane, self.inplane, stride=2)
+ self.block3_2 = block_class(pdcs[9], self.inplane, self.inplane)
+ self.block3_3 = block_class(pdcs[10], self.inplane, self.inplane)
+ self.block3_4 = block_class(pdcs[11], self.inplane, self.inplane)
+ self.fuseplanes.append(self.inplane) # 4C
+
+ self.block4_1 = block_class(pdcs[12], self.inplane, self.inplane, stride=2)
+ self.block4_2 = block_class(pdcs[13], self.inplane, self.inplane)
+ self.block4_3 = block_class(pdcs[14], self.inplane, self.inplane)
+ self.block4_4 = block_class(pdcs[15], self.inplane, self.inplane)
+ self.fuseplanes.append(self.inplane) # 4C
+
+ self.conv_reduces = nn.ModuleList()
+ if self.sa and self.dil is not None:
+ self.attentions = nn.ModuleList()
+ self.dilations = nn.ModuleList()
+ for i in range(4):
+ self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
+ self.attentions.append(CSAM(self.dil))
+ self.conv_reduces.append(MapReduce(self.dil))
+ elif self.sa:
+ self.attentions = nn.ModuleList()
+ for i in range(4):
+ self.attentions.append(CSAM(self.fuseplanes[i]))
+ self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
+ elif self.dil is not None:
+ self.dilations = nn.ModuleList()
+ for i in range(4):
+ self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
+ self.conv_reduces.append(MapReduce(self.dil))
+ else:
+ for i in range(4):
+ self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
+
+ self.classifier = nn.Conv2d(4, 1, kernel_size=1) # has bias
+ nn.init.constant_(self.classifier.weight, 0.25)
+ nn.init.constant_(self.classifier.bias, 0)
+
+ # print('initialization done')
+
+ def get_weights(self):
+ conv_weights = []
+ bn_weights = []
+ relu_weights = []
+ for pname, p in self.named_parameters():
+ if 'bn' in pname:
+ bn_weights.append(p)
+ elif 'relu' in pname:
+ relu_weights.append(p)
+ else:
+ conv_weights.append(p)
+
+ return conv_weights, bn_weights, relu_weights
+
+ def forward(self, x):
+ H, W = x.size()[2:]
+
+ x = self.init_block(x)
+
+ x1 = self.block1_1(x)
+ x1 = self.block1_2(x1)
+ x1 = self.block1_3(x1)
+
+ x2 = self.block2_1(x1)
+ x2 = self.block2_2(x2)
+ x2 = self.block2_3(x2)
+ x2 = self.block2_4(x2)
+
+ x3 = self.block3_1(x2)
+ x3 = self.block3_2(x3)
+ x3 = self.block3_3(x3)
+ x3 = self.block3_4(x3)
+
+ x4 = self.block4_1(x3)
+ x4 = self.block4_2(x4)
+ x4 = self.block4_3(x4)
+ x4 = self.block4_4(x4)
+
+ x_fuses = []
+ if self.sa and self.dil is not None:
+ for i, xi in enumerate([x1, x2, x3, x4]):
+ x_fuses.append(self.attentions[i](self.dilations[i](xi)))
+ elif self.sa:
+ for i, xi in enumerate([x1, x2, x3, x4]):
+ x_fuses.append(self.attentions[i](xi))
+ elif self.dil is not None:
+ for i, xi in enumerate([x1, x2, x3, x4]):
+ x_fuses.append(self.dilations[i](xi))
+ else:
+ x_fuses = [x1, x2, x3, x4]
+
+ e1 = self.conv_reduces[0](x_fuses[0])
+ e1 = F.interpolate(e1, (H, W), mode="bilinear", align_corners=False)
+
+ e2 = self.conv_reduces[1](x_fuses[1])
+ e2 = F.interpolate(e2, (H, W), mode="bilinear", align_corners=False)
+
+ e3 = self.conv_reduces[2](x_fuses[2])
+ e3 = F.interpolate(e3, (H, W), mode="bilinear", align_corners=False)
+
+ e4 = self.conv_reduces[3](x_fuses[3])
+ e4 = F.interpolate(e4, (H, W), mode="bilinear", align_corners=False)
+
+ outputs = [e1, e2, e3, e4]
+
+ output = self.classifier(torch.cat(outputs, dim=1))
+ #if not self.training:
+ # return torch.sigmoid(output)
+
+ outputs.append(output)
+ outputs = [torch.sigmoid(r) for r in outputs]
+ return outputs
+
+def config_model(model):
+ model_options = list(nets.keys())
+ assert model in model_options, \
+ 'unrecognized model, please choose from %s' % str(model_options)
+
+ # print(str(nets[model]))
+
+ pdcs = []
+ for i in range(16):
+ layer_name = 'layer%d' % i
+ op = nets[model][layer_name]
+ pdcs.append(createConvFunc(op))
+
+ return pdcs
+
+def pidinet():
+ pdcs = config_model('carv4')
+ dil = 24 #if args.dil else None
+ return PiDiNet(60, pdcs, dil=dil, sa=True)
+
+
+if __name__ == '__main__':
+ model = pidinet()
+ ckp = torch.load('table5_pidinet.pth')['state_dict']
+ model.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()})
+ im = cv2.imread('examples/test_my/cat_v4.png')
+ im = img2tensor(im).unsqueeze(0)/255.
+ res = model(im)[-1]
+ res = res>0.5
+ res = res.float()
+ res = (res[0,0].cpu().data.numpy()*255.).astype(np.uint8)
+ print(res.shape)
+ cv2.imwrite('edge.png', res)
diff --git a/src/custom_controlnet_aux/processor.py b/src/custom_controlnet_aux/processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8e087e2179adc7e03b115744eb9ef1e6f747bdb
--- /dev/null
+++ b/src/custom_controlnet_aux/processor.py
@@ -0,0 +1,147 @@
+"""
+This file contains a Processor that can be used to process images with controlnet aux processors
+"""
+import io
+import logging
+from typing import Dict, Optional, Union
+
+from PIL import Image
+
+from custom_controlnet_aux import (CannyDetector, ContentShuffleDetector, HEDdetector,
+ LeresDetector, LineartAnimeDetector,
+ LineartDetector, MediapipeFaceDetector,
+ MidasDetector, MLSDdetector, NormalBaeDetector,
+ OpenposeDetector, PidiNetDetector, ZoeDetector, TileDetector)
+
+LOGGER = logging.getLogger(__name__)
+
+
+MODELS = {
+ # checkpoint models
+ 'scribble_hed': {'class': HEDdetector, 'checkpoint': True},
+ 'softedge_hed': {'class': HEDdetector, 'checkpoint': True},
+ 'scribble_hedsafe': {'class': HEDdetector, 'checkpoint': True},
+ 'softedge_hedsafe': {'class': HEDdetector, 'checkpoint': True},
+ 'depth_midas': {'class': MidasDetector, 'checkpoint': True},
+ 'mlsd': {'class': MLSDdetector, 'checkpoint': True},
+ 'openpose': {'class': OpenposeDetector, 'checkpoint': True},
+ 'openpose_face': {'class': OpenposeDetector, 'checkpoint': True},
+ 'openpose_faceonly': {'class': OpenposeDetector, 'checkpoint': True},
+ 'openpose_full': {'class': OpenposeDetector, 'checkpoint': True},
+ 'openpose_hand': {'class': OpenposeDetector, 'checkpoint': True},
+ 'scribble_pidinet': {'class': PidiNetDetector, 'checkpoint': True},
+ 'softedge_pidinet': {'class': PidiNetDetector, 'checkpoint': True},
+ 'scribble_pidsafe': {'class': PidiNetDetector, 'checkpoint': True},
+ 'softedge_pidsafe': {'class': PidiNetDetector, 'checkpoint': True},
+ 'normal_bae': {'class': NormalBaeDetector, 'checkpoint': True},
+ 'lineart_coarse': {'class': LineartDetector, 'checkpoint': True},
+ 'lineart_realistic': {'class': LineartDetector, 'checkpoint': True},
+ 'lineart_anime': {'class': LineartAnimeDetector, 'checkpoint': True},
+ 'depth_zoe': {'class': ZoeDetector, 'checkpoint': True},
+ 'depth_leres': {'class': LeresDetector, 'checkpoint': True},
+ 'depth_leres++': {'class': LeresDetector, 'checkpoint': True},
+ # instantiate
+ 'shuffle': {'class': ContentShuffleDetector, 'checkpoint': False},
+ 'mediapipe_face': {'class': MediapipeFaceDetector, 'checkpoint': False},
+ 'canny': {'class': CannyDetector, 'checkpoint': False},
+ 'tile': {'class': TileDetector, 'checkpoint': False},
+}
+
+
+MODEL_PARAMS = {
+ 'scribble_hed': {'scribble': True},
+ 'softedge_hed': {'scribble': False},
+ 'scribble_hedsafe': {'scribble': True, 'safe': True},
+ 'softedge_hedsafe': {'scribble': False, 'safe': True},
+ 'depth_midas': {},
+ 'mlsd': {},
+ 'openpose': {'include_body': True, 'include_hand': False, 'include_face': False},
+ 'openpose_face': {'include_body': True, 'include_hand': False, 'include_face': True},
+ 'openpose_faceonly': {'include_body': False, 'include_hand': False, 'include_face': True},
+ 'openpose_full': {'include_body': True, 'include_hand': True, 'include_face': True},
+ 'openpose_hand': {'include_body': False, 'include_hand': True, 'include_face': False},
+ 'scribble_pidinet': {'safe': False, 'scribble': True},
+ 'softedge_pidinet': {'safe': False, 'scribble': False},
+ 'scribble_pidsafe': {'safe': True, 'scribble': True},
+ 'softedge_pidsafe': {'safe': True, 'scribble': False},
+ 'normal_bae': {},
+ 'lineart_realistic': {'coarse': False},
+ 'lineart_coarse': {'coarse': True},
+ 'lineart_anime': {},
+ 'canny': {},
+ 'shuffle': {},
+ 'depth_zoe': {},
+ 'depth_leres': {'boost': False},
+ 'depth_leres++': {'boost': True},
+ 'mediapipe_face': {},
+ 'tile': {},
+}
+
+CHOICES = f"Choices for the processor are {list(MODELS.keys())}"
+
+
+class Processor:
+ def __init__(self, processor_id: str, params: Optional[Dict] = None) -> None:
+ """Processor that can be used to process images with controlnet aux processors
+
+ Args:
+ processor_id (str): processor name, options are 'hed, midas, mlsd, openpose,
+ pidinet, normalbae, lineart, lineart_coarse, lineart_anime,
+ canny, content_shuffle, zoe, mediapipe_face, tile'
+ params (Optional[Dict]): parameters for the processor
+ """
+ LOGGER.info("Loading %s".format(processor_id))
+
+ if processor_id not in MODELS:
+ raise ValueError(f"{processor_id} is not a valid processor id. Please make sure to choose one of {', '.join(MODELS.keys())}")
+
+ self.processor_id = processor_id
+ self.processor = self.load_processor(self.processor_id)
+
+ # load default params
+ self.params = MODEL_PARAMS[self.processor_id]
+ # update with user params
+ if params:
+ self.params.update(params)
+
+ def load_processor(self, processor_id: str) -> 'Processor':
+ """Load controlnet aux processors
+
+ Args:
+ processor_id (str): processor name
+
+ Returns:
+ Processor: controlnet aux processor
+ """
+ processor = MODELS[processor_id]['class']
+
+ # check if the proecssor is a checkpoint model
+ if MODELS[processor_id]['checkpoint']:
+ processor = processor.from_pretrained("lllyasviel/Annotators")
+ else:
+ processor = processor()
+ return processor
+
+ def __call__(self, image: Union[Image.Image, bytes],
+ to_pil: bool = True) -> Union[Image.Image, bytes]:
+ """processes an image with a controlnet aux processor
+
+ Args:
+ image (Union[Image.Image, bytes]): input image in bytes or PIL Image
+ to_pil (bool): whether to return bytes or PIL Image
+
+ Returns:
+ Union[Image.Image, bytes]: processed image in bytes or PIL Image
+ """
+ # check if bytes or PIL Image
+ if isinstance(image, bytes):
+ image = Image.open(io.BytesIO(image)).convert("RGB")
+
+ processed_image = self.processor(image, **self.params)
+
+ if to_pil:
+ return processed_image
+ else:
+ output_bytes = io.BytesIO()
+ processed_image.save(output_bytes, format='JPEG')
+ return output_bytes.getvalue()
diff --git a/src/custom_controlnet_aux/pyracanny/__init__.py b/src/custom_controlnet_aux/pyracanny/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c45193b856469f8cec4e9b0e9e74cdba1e92d4f7
--- /dev/null
+++ b/src/custom_controlnet_aux/pyracanny/__init__.py
@@ -0,0 +1,74 @@
+import warnings
+import cv2
+import numpy as np
+from PIL import Image
+from custom_controlnet_aux.util import resize_image_with_pad, common_input_validate, HWC3
+
+def centered_canny(x: np.ndarray, canny_low_threshold, canny_high_threshold):
+ assert isinstance(x, np.ndarray)
+ assert x.ndim == 2 and x.dtype == np.uint8
+
+ y = cv2.Canny(x, int(canny_low_threshold), int(canny_high_threshold))
+ y = y.astype(np.float32) / 255.0
+ return y
+
+def centered_canny_color(x: np.ndarray, canny_low_threshold, canny_high_threshold):
+ assert isinstance(x, np.ndarray)
+ assert x.ndim == 3 and x.shape[2] == 3
+
+ result = [centered_canny(x[..., i], canny_low_threshold, canny_high_threshold) for i in range(3)]
+ result = np.stack(result, axis=2)
+ return result
+
+def pyramid_canny_color(x: np.ndarray, canny_low_threshold, canny_high_threshold):
+ assert isinstance(x, np.ndarray)
+ assert x.ndim == 3 and x.shape[2] == 3
+
+ H, W, C = x.shape
+ acc_edge = None
+
+ for k in [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]:
+ Hs, Ws = int(H * k), int(W * k)
+ small = cv2.resize(x, (Ws, Hs), interpolation=cv2.INTER_AREA)
+ edge = centered_canny_color(small, canny_low_threshold, canny_high_threshold)
+ if acc_edge is None:
+ acc_edge = edge
+ else:
+ acc_edge = cv2.resize(acc_edge, (edge.shape[1], edge.shape[0]), interpolation=cv2.INTER_LINEAR)
+ acc_edge = acc_edge * 0.75 + edge * 0.25
+
+ return acc_edge
+
+def norm255(x, low=4, high=96):
+ assert isinstance(x, np.ndarray)
+ assert x.ndim == 2 and x.dtype == np.float32
+
+ v_min = np.percentile(x, low)
+ v_max = np.percentile(x, high)
+
+ x -= v_min
+ x /= v_max - v_min
+
+ return x * 255.0
+
+def canny_pyramid(x, canny_low_threshold, canny_high_threshold):
+ # For some reasons, SAI's Control-lora Canny seems to be trained on canny maps with non-standard resolutions.
+ # Then we use pyramid to use all resolutions to avoid missing any structure in specific resolutions.
+
+ color_canny = pyramid_canny_color(x, canny_low_threshold, canny_high_threshold)
+ result = np.sum(color_canny, axis=2)
+
+ return norm255(result, low=1, high=99).clip(0, 255).astype(np.uint8)
+
+class PyraCannyDetector:
+ def __call__(self, input_image=None, low_threshold=100, high_threshold=200, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+ detected_map = canny_pyramid(detected_map, low_threshold, high_threshold)
+ detected_map = HWC3(remove_pad(detected_map))
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
+
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/recolor/__init__.py b/src/custom_controlnet_aux/recolor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..274ce3a73358450e4fe203cf9f913ef6f88fdb50
--- /dev/null
+++ b/src/custom_controlnet_aux/recolor/__init__.py
@@ -0,0 +1,39 @@
+import warnings
+import cv2
+import numpy as np
+from PIL import Image
+from custom_controlnet_aux.util import resize_image_with_pad, common_input_validate, HWC3
+
+#https://github.com/Mikubill/sd-webui-controlnet/blob/416c345072c9c2066101e225964e3986abe6945e/scripts/processor.py#L639
+def recolor_luminance(img, thr_a=1.0):
+ result = cv2.cvtColor(HWC3(img), cv2.COLOR_BGR2LAB)
+ result = result[:, :, 0].astype(np.float32) / 255.0
+ result = result ** thr_a
+ result = (result * 255.0).clip(0, 255).astype(np.uint8)
+ result = cv2.cvtColor(result, cv2.COLOR_GRAY2RGB)
+ return result
+
+
+def recolor_intensity(img, thr_a=1.0):
+ result = cv2.cvtColor(HWC3(img), cv2.COLOR_BGR2HSV)
+ result = result[:, :, 2].astype(np.float32) / 255.0
+ result = result ** thr_a
+ result = (result * 255.0).clip(0, 255).astype(np.uint8)
+ result = cv2.cvtColor(result, cv2.COLOR_GRAY2RGB)
+ return result
+
+recolor_methods = {
+ "luminance": recolor_luminance,
+ "intensity": recolor_intensity
+}
+
+class Recolorizer:
+ def __call__(self, input_image=None, mode="luminance", gamma_correction=1.0, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+ assert mode in recolor_methods.keys()
+ detected_map = recolor_methods[mode](input_image, gamma_correction)
+ detected_map = HWC3(remove_pad(detected_map))
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+ return detected_map
diff --git a/src/custom_controlnet_aux/sam/__init__.py b/src/custom_controlnet_aux/sam/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..55f3cfde1d82bbf019dfe68e363044a9360d8f04
--- /dev/null
+++ b/src/custom_controlnet_aux/sam/__init__.py
@@ -0,0 +1 @@
+from .sam import SamDetector
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/sam/automatic_mask_generator.py b/src/custom_controlnet_aux/sam/automatic_mask_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5a8c969207f119feff7087f94e044403acdff00
--- /dev/null
+++ b/src/custom_controlnet_aux/sam/automatic_mask_generator.py
@@ -0,0 +1,372 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+from torchvision.ops.boxes import batched_nms, box_area # type: ignore
+
+from typing import Any, Dict, List, Optional, Tuple
+
+from .modeling import Sam
+from .predictor import SamPredictor
+from .utils.amg import (
+ MaskData,
+ area_from_rle,
+ batch_iterator,
+ batched_mask_to_box,
+ box_xyxy_to_xywh,
+ build_all_layer_point_grids,
+ calculate_stability_score,
+ coco_encode_rle,
+ generate_crop_boxes,
+ is_box_near_crop_edge,
+ mask_to_rle_pytorch,
+ remove_small_regions,
+ rle_to_mask,
+ uncrop_boxes_xyxy,
+ uncrop_masks,
+ uncrop_points,
+)
+
+
+class SamAutomaticMaskGenerator:
+ def __init__(
+ self,
+ model: Sam,
+ points_per_side: Optional[int] = 32,
+ points_per_batch: int = 64,
+ pred_iou_thresh: float = 0.88,
+ stability_score_thresh: float = 0.95,
+ stability_score_offset: float = 1.0,
+ box_nms_thresh: float = 0.7,
+ crop_n_layers: int = 0,
+ crop_nms_thresh: float = 0.7,
+ crop_overlap_ratio: float = 512 / 1500,
+ crop_n_points_downscale_factor: int = 1,
+ point_grids: Optional[List[np.ndarray]] = None,
+ min_mask_region_area: int = 0,
+ output_mode: str = "binary_mask",
+ ) -> None:
+ """
+ Using a SAM model, generates masks for the entire image.
+ Generates a grid of point prompts over the image, then filters
+ low quality and duplicate masks. The default settings are chosen
+ for SAM with a ViT-H backbone.
+
+ Arguments:
+ model (Sam): The SAM model to use for mask prediction.
+ points_per_side (int or None): The number of points to be sampled
+ along one side of the image. The total number of points is
+ points_per_side**2. If None, 'point_grids' must provide explicit
+ point sampling.
+ points_per_batch (int): Sets the number of points run simultaneously
+ by the model. Higher numbers may be faster but use more GPU memory.
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
+ model's predicted mask quality.
+ stability_score_thresh (float): A filtering threshold in [0,1], using
+ the stability of the mask under changes to the cutoff used to binarize
+ the model's mask predictions.
+ stability_score_offset (float): The amount to shift the cutoff when
+ calculated the stability score.
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
+ suppression to filter duplicate masks.
+ crop_n_layers (int): If >0, mask prediction will be run again on
+ crops of the image. Sets the number of layers to run, where each
+ layer has 2**i_layer number of image crops.
+ crop_nms_thresh (float): The box IoU cutoff used by non-maximal
+ suppression to filter duplicate masks between different crops.
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
+ In the first crop layer, crops will overlap by this fraction of
+ the image length. Later layers with more crops scale down this overlap.
+ crop_n_points_downscale_factor (int): The number of points-per-side
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
+ point_grids (list(np.ndarray) or None): A list over explicit grids
+ of points used for sampling, normalized to [0,1]. The nth grid in the
+ list is used in the nth crop layer. Exclusive with points_per_side.
+ min_mask_region_area (int): If >0, postprocessing will be applied
+ to remove disconnected regions and holes in masks with area smaller
+ than min_mask_region_area. Requires opencv.
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
+ For large resolutions, 'binary_mask' may consume large amounts of
+ memory.
+ """
+
+ assert (points_per_side is None) != (
+ point_grids is None
+ ), "Exactly one of points_per_side or point_grid must be provided."
+ if points_per_side is not None:
+ self.point_grids = build_all_layer_point_grids(
+ points_per_side,
+ crop_n_layers,
+ crop_n_points_downscale_factor,
+ )
+ elif point_grids is not None:
+ self.point_grids = point_grids
+ else:
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
+
+ assert output_mode in [
+ "binary_mask",
+ "uncompressed_rle",
+ "coco_rle",
+ ], f"Unknown output_mode {output_mode}."
+ if output_mode == "coco_rle":
+ from pycocotools import mask as mask_utils # type: ignore # noqa: F401
+
+ if min_mask_region_area > 0:
+ import cv2 # type: ignore # noqa: F401
+
+ self.predictor = SamPredictor(model)
+ self.points_per_batch = points_per_batch
+ self.pred_iou_thresh = pred_iou_thresh
+ self.stability_score_thresh = stability_score_thresh
+ self.stability_score_offset = stability_score_offset
+ self.box_nms_thresh = box_nms_thresh
+ self.crop_n_layers = crop_n_layers
+ self.crop_nms_thresh = crop_nms_thresh
+ self.crop_overlap_ratio = crop_overlap_ratio
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
+ self.min_mask_region_area = min_mask_region_area
+ self.output_mode = output_mode
+
+ @torch.no_grad()
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
+ """
+ Generates masks for the given image.
+
+ Arguments:
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
+
+ Returns:
+ list(dict(str, any)): A list over records for masks. Each record is
+ a dict containing the following keys:
+ segmentation (dict(str, any) or np.ndarray): The mask. If
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
+ is a dictionary containing the RLE.
+ bbox (list(float)): The box around the mask, in XYWH format.
+ area (int): The area in pixels of the mask.
+ predicted_iou (float): The model's own prediction of the mask's
+ quality. This is filtered by the pred_iou_thresh parameter.
+ point_coords (list(list(float))): The point coordinates input
+ to the model to generate this mask.
+ stability_score (float): A measure of the mask's quality. This
+ is filtered on using the stability_score_thresh parameter.
+ crop_box (list(float)): The crop of the image used to generate
+ the mask, given in XYWH format.
+ """
+
+ # Generate masks
+ mask_data = self._generate_masks(image)
+
+ # Filter small disconnected regions and holes in masks
+ if self.min_mask_region_area > 0:
+ mask_data = self.postprocess_small_regions(
+ mask_data,
+ self.min_mask_region_area,
+ max(self.box_nms_thresh, self.crop_nms_thresh),
+ )
+
+ # Encode masks
+ if self.output_mode == "coco_rle":
+ mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
+ elif self.output_mode == "binary_mask":
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
+ else:
+ mask_data["segmentations"] = mask_data["rles"]
+
+ # Write mask records
+ curr_anns = []
+ for idx in range(len(mask_data["segmentations"])):
+ ann = {
+ "segmentation": mask_data["segmentations"][idx],
+ "area": area_from_rle(mask_data["rles"][idx]),
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
+ "point_coords": [mask_data["points"][idx].tolist()],
+ "stability_score": mask_data["stability_score"][idx].item(),
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
+ }
+ curr_anns.append(ann)
+
+ return curr_anns
+
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
+ orig_size = image.shape[:2]
+ crop_boxes, layer_idxs = generate_crop_boxes(
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
+ )
+
+ # Iterate over image crops
+ data = MaskData()
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
+ data.cat(crop_data)
+
+ # Remove duplicate masks between crops
+ if len(crop_boxes) > 1:
+ # Prefer masks from smaller crops
+ scores = 1 / box_area(data["crop_boxes"])
+ scores = scores.to(data["boxes"].device)
+ keep_by_nms = batched_nms(
+ data["boxes"].float(),
+ scores,
+ torch.zeros_like(data["boxes"][:, 0]), # categories
+ iou_threshold=self.crop_nms_thresh,
+ )
+ data.filter(keep_by_nms)
+
+ data.to_numpy()
+ return data
+
+ def _process_crop(
+ self,
+ image: np.ndarray,
+ crop_box: List[int],
+ crop_layer_idx: int,
+ orig_size: Tuple[int, ...],
+ ) -> MaskData:
+ # Crop the image and calculate embeddings
+ x0, y0, x1, y1 = crop_box
+ cropped_im = image[y0:y1, x0:x1, :]
+ cropped_im_size = cropped_im.shape[:2]
+ self.predictor.set_image(cropped_im)
+
+ # Get points for this crop
+ points_scale = np.array(cropped_im_size)[None, ::-1]
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
+
+ # Generate masks for this crop in batches
+ data = MaskData()
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
+ batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
+ data.cat(batch_data)
+ del batch_data
+ self.predictor.reset_image()
+
+ # Remove duplicates within this crop.
+ keep_by_nms = batched_nms(
+ data["boxes"].float(),
+ data["iou_preds"],
+ torch.zeros_like(data["boxes"][:, 0]), # categories
+ iou_threshold=self.box_nms_thresh,
+ )
+ data.filter(keep_by_nms)
+
+ # Return to the original image frame
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
+ data["points"] = uncrop_points(data["points"], crop_box)
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
+
+ return data
+
+ def _process_batch(
+ self,
+ points: np.ndarray,
+ im_size: Tuple[int, ...],
+ crop_box: List[int],
+ orig_size: Tuple[int, ...],
+ ) -> MaskData:
+ orig_h, orig_w = orig_size
+
+ # Run model on this batch
+ transformed_points = self.predictor.transform.apply_coords(points, im_size)
+ in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
+ in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
+ masks, iou_preds, _ = self.predictor.predict_torch(
+ in_points[:, None, :],
+ in_labels[:, None],
+ multimask_output=True,
+ return_logits=True,
+ )
+
+ # Serialize predictions and store in MaskData
+ data = MaskData(
+ masks=masks.flatten(0, 1),
+ iou_preds=iou_preds.flatten(0, 1),
+ points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
+ )
+ del masks
+
+ # Filter by predicted IoU
+ if self.pred_iou_thresh > 0.0:
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
+ data.filter(keep_mask)
+
+ # Calculate stability score
+ data["stability_score"] = calculate_stability_score(
+ data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
+ )
+ if self.stability_score_thresh > 0.0:
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
+ data.filter(keep_mask)
+
+ # Threshold masks and calculate boxes
+ data["masks"] = data["masks"] > self.predictor.model.mask_threshold
+ data["boxes"] = batched_mask_to_box(data["masks"])
+
+ # Filter boxes that touch crop boundaries
+ keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
+ if not torch.all(keep_mask):
+ data.filter(keep_mask)
+
+ # Compress to RLE
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
+ del data["masks"]
+
+ return data
+
+ @staticmethod
+ def postprocess_small_regions(
+ mask_data: MaskData, min_area: int, nms_thresh: float
+ ) -> MaskData:
+ """
+ Removes small disconnected regions and holes in masks, then reruns
+ box NMS to remove any new duplicates.
+
+ Edits mask_data in place.
+
+ Requires open-cv as a dependency.
+ """
+ if len(mask_data["rles"]) == 0:
+ return mask_data
+
+ # Filter small disconnected regions and holes
+ new_masks = []
+ scores = []
+ for rle in mask_data["rles"]:
+ mask = rle_to_mask(rle)
+
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
+ unchanged = not changed
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
+ unchanged = unchanged and not changed
+
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
+ # Give score=0 to changed masks and score=1 to unchanged masks
+ # so NMS will prefer ones that didn't need postprocessing
+ scores.append(float(unchanged))
+
+ # Recalculate boxes and remove any new duplicates
+ masks = torch.cat(new_masks, dim=0)
+ boxes = batched_mask_to_box(masks)
+ keep_by_nms = batched_nms(
+ boxes.float(),
+ torch.as_tensor(scores),
+ torch.zeros_like(boxes[:, 0]), # categories
+ iou_threshold=nms_thresh,
+ )
+
+ # Only recalculate RLEs for masks that have changed
+ for i_mask in keep_by_nms:
+ if scores[i_mask] == 0.0:
+ mask_torch = masks[i_mask].unsqueeze(0)
+ mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
+ mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
+ mask_data.filter(keep_by_nms)
+
+ return mask_data
diff --git a/src/custom_controlnet_aux/sam/modeling/__init__.py b/src/custom_controlnet_aux/sam/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ed986c5e79ef4261622372a5be725e31906417b
--- /dev/null
+++ b/src/custom_controlnet_aux/sam/modeling/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .sam import Sam
+from .image_encoder import ImageEncoderViT
+from .mask_decoder import MaskDecoder
+from .prompt_encoder import PromptEncoder
+from .transformer import TwoWayTransformer
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/sam/modeling/common.py b/src/custom_controlnet_aux/sam/modeling/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bf15236a3eb24d8526073bc4fa2b274cccb3f96
--- /dev/null
+++ b/src/custom_controlnet_aux/sam/modeling/common.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+from typing import Type
+
+
+class MLPBlock(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ mlp_dim: int,
+ act: Type[nn.Module] = nn.GELU,
+ ) -> None:
+ super().__init__()
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
+ self.act = act()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.lin2(self.act(self.lin1(x)))
+
+
+# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
+# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
+class LayerNorm2d(nn.Module):
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(num_channels))
+ self.bias = nn.Parameter(torch.zeros(num_channels))
+ self.eps = eps
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
diff --git a/src/custom_controlnet_aux/sam/modeling/image_encoder.py b/src/custom_controlnet_aux/sam/modeling/image_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..66351d9d7c589be693f4b3485901d3bdfed54d4a
--- /dev/null
+++ b/src/custom_controlnet_aux/sam/modeling/image_encoder.py
@@ -0,0 +1,395 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from typing import Optional, Tuple, Type
+
+from .common import LayerNorm2d, MLPBlock
+
+
+# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
+class ImageEncoderViT(nn.Module):
+ def __init__(
+ self,
+ img_size: int = 1024,
+ patch_size: int = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ depth: int = 12,
+ num_heads: int = 12,
+ mlp_ratio: float = 4.0,
+ out_chans: int = 256,
+ qkv_bias: bool = True,
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
+ act_layer: Type[nn.Module] = nn.GELU,
+ use_abs_pos: bool = True,
+ use_rel_pos: bool = False,
+ rel_pos_zero_init: bool = True,
+ window_size: int = 0,
+ global_attn_indexes: Tuple[int, ...] = (),
+ ) -> None:
+ """
+ Args:
+ img_size (int): Input image size.
+ patch_size (int): Patch size.
+ in_chans (int): Number of input image channels.
+ embed_dim (int): Patch embedding dimension.
+ depth (int): Depth of ViT.
+ num_heads (int): Number of attention heads in each ViT block.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
+ norm_layer (nn.Module): Normalization layer.
+ act_layer (nn.Module): Activation layer.
+ use_abs_pos (bool): If True, use absolute positional embeddings.
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+ window_size (int): Window size for window attention blocks.
+ global_attn_indexes (list): Indexes for blocks using global attention.
+ """
+ super().__init__()
+ self.img_size = img_size
+
+ self.patch_embed = PatchEmbed(
+ kernel_size=(patch_size, patch_size),
+ stride=(patch_size, patch_size),
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ )
+
+ self.pos_embed: Optional[nn.Parameter] = None
+ if use_abs_pos:
+ # Initialize absolute positional embedding with pretrain image size.
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
+ )
+
+ self.blocks = nn.ModuleList()
+ for i in range(depth):
+ block = Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ use_rel_pos=use_rel_pos,
+ rel_pos_zero_init=rel_pos_zero_init,
+ window_size=window_size if i not in global_attn_indexes else 0,
+ input_size=(img_size // patch_size, img_size // patch_size),
+ )
+ self.blocks.append(block)
+
+ self.neck = nn.Sequential(
+ nn.Conv2d(
+ embed_dim,
+ out_chans,
+ kernel_size=1,
+ bias=False,
+ ),
+ LayerNorm2d(out_chans),
+ nn.Conv2d(
+ out_chans,
+ out_chans,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ ),
+ LayerNorm2d(out_chans),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.patch_embed(x)
+ if self.pos_embed is not None:
+ x = x + self.pos_embed
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x = self.neck(x.permute(0, 3, 1, 2))
+
+ return x
+
+
+class Block(nn.Module):
+ """Transformer blocks with support of window attention and residual propagation blocks"""
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = True,
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
+ act_layer: Type[nn.Module] = nn.GELU,
+ use_rel_pos: bool = False,
+ rel_pos_zero_init: bool = True,
+ window_size: int = 0,
+ input_size: Optional[Tuple[int, int]] = None,
+ ) -> None:
+ """
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads in each ViT block.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
+ norm_layer (nn.Module): Normalization layer.
+ act_layer (nn.Module): Activation layer.
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+ window_size (int): Window size for window attention blocks. If it equals 0, then
+ use global attention.
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
+ positional parameter size.
+ """
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ use_rel_pos=use_rel_pos,
+ rel_pos_zero_init=rel_pos_zero_init,
+ input_size=input_size if window_size == 0 else (window_size, window_size),
+ )
+
+ self.norm2 = norm_layer(dim)
+ self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
+
+ self.window_size = window_size
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ shortcut = x
+ x = self.norm1(x)
+ # Window partition
+ if self.window_size > 0:
+ H, W = x.shape[1], x.shape[2]
+ x, pad_hw = window_partition(x, self.window_size)
+
+ x = self.attn(x)
+ # Reverse window partition
+ if self.window_size > 0:
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+
+ x = shortcut + x
+ x = x + self.mlp(self.norm2(x))
+
+ return x
+
+
+class Attention(nn.Module):
+ """Multi-head Attention block with relative position embeddings."""
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = True,
+ use_rel_pos: bool = False,
+ rel_pos_zero_init: bool = True,
+ input_size: Optional[Tuple[int, int]] = None,
+ ) -> None:
+ """
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
+ rel_pos (bool): If True, add relative positional embeddings to the attention map.
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
+ positional parameter size.
+ """
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim)
+
+ self.use_rel_pos = use_rel_pos
+ if self.use_rel_pos:
+ assert (
+ input_size is not None
+ ), "Input size must be provided if using relative positional encoding."
+ # initialize relative positional embeddings
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, H, W, _ = x.shape
+ # qkv with shape (3, B, nHead, H * W, C)
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ # q, k, v with shape (B * nHead, H * W, C)
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
+
+ attn = (q * self.scale) @ k.transpose(-2, -1)
+
+ if self.use_rel_pos:
+ attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
+
+ attn = attn.softmax(dim=-1)
+ x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
+ x = self.proj(x)
+
+ return x
+
+
+def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
+ """
+ Partition into non-overlapping windows with padding if needed.
+ Args:
+ x (tensor): input tokens with [B, H, W, C].
+ window_size (int): window size.
+
+ Returns:
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
+ (Hp, Wp): padded height and width before partition
+ """
+ B, H, W, C = x.shape
+
+ pad_h = (window_size - H % window_size) % window_size
+ pad_w = (window_size - W % window_size) % window_size
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+ Hp, Wp = H + pad_h, W + pad_w
+
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows, (Hp, Wp)
+
+
+def window_unpartition(
+ windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
+) -> torch.Tensor:
+ """
+ Window unpartition into original sequences and removing padding.
+ Args:
+ windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+ window_size (int): window size.
+ pad_hw (Tuple): padded height and width (Hp, Wp).
+ hw (Tuple): original height and width (H, W) before padding.
+
+ Returns:
+ x: unpartitioned sequences with [B, H, W, C].
+ """
+ Hp, Wp = pad_hw
+ H, W = hw
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+
+ if Hp > H or Wp > W:
+ x = x[:, :H, :W, :].contiguous()
+ return x
+
+
+def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
+ """
+ Get relative positional embeddings according to the relative positions of
+ query and key sizes.
+ Args:
+ q_size (int): size of query q.
+ k_size (int): size of key k.
+ rel_pos (Tensor): relative position embeddings (L, C).
+
+ Returns:
+ Extracted positional embeddings according to relative positions.
+ """
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
+ # Interpolate rel pos if needed.
+ if rel_pos.shape[0] != max_rel_dist:
+ # Interpolate rel pos.
+ rel_pos_resized = F.interpolate(
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+ size=max_rel_dist,
+ mode="linear",
+ )
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+ else:
+ rel_pos_resized = rel_pos
+
+ # Scale the coords with short length if shapes for q and k are different.
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+ return rel_pos_resized[relative_coords.long()]
+
+
+def add_decomposed_rel_pos(
+ attn: torch.Tensor,
+ q: torch.Tensor,
+ rel_pos_h: torch.Tensor,
+ rel_pos_w: torch.Tensor,
+ q_size: Tuple[int, int],
+ k_size: Tuple[int, int],
+) -> torch.Tensor:
+ """
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
+ Args:
+ attn (Tensor): attention map.
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
+
+ Returns:
+ attn (Tensor): attention map with added relative positional embeddings.
+ """
+ q_h, q_w = q_size
+ k_h, k_w = k_size
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+
+ B, _, dim = q.shape
+ r_q = q.reshape(B, q_h, q_w, dim)
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+
+ attn = (
+ attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
+ ).view(B, q_h * q_w, k_h * k_w)
+
+ return attn
+
+
+class PatchEmbed(nn.Module):
+ """
+ Image to Patch Embedding.
+ """
+
+ def __init__(
+ self,
+ kernel_size: Tuple[int, int] = (16, 16),
+ stride: Tuple[int, int] = (16, 16),
+ padding: Tuple[int, int] = (0, 0),
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ ) -> None:
+ """
+ Args:
+ kernel_size (Tuple): kernel size of the projection layer.
+ stride (Tuple): stride of the projection layer.
+ padding (Tuple): padding size of the projection layer.
+ in_chans (int): Number of input image channels.
+ embed_dim (int): Patch embedding dimension.
+ """
+ super().__init__()
+
+ self.proj = nn.Conv2d(
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.proj(x)
+ # B C H W -> B H W C
+ x = x.permute(0, 2, 3, 1)
+ return x
diff --git a/src/custom_controlnet_aux/sam/modeling/mask_decoder.py b/src/custom_controlnet_aux/sam/modeling/mask_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d2fdb03d535a91fa725d1ec4e92a7a1f217dfe0
--- /dev/null
+++ b/src/custom_controlnet_aux/sam/modeling/mask_decoder.py
@@ -0,0 +1,176 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from typing import List, Tuple, Type
+
+from .common import LayerNorm2d
+
+
+class MaskDecoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ transformer_dim: int,
+ transformer: nn.Module,
+ num_multimask_outputs: int = 3,
+ activation: Type[nn.Module] = nn.GELU,
+ iou_head_depth: int = 3,
+ iou_head_hidden_dim: int = 256,
+ ) -> None:
+ """
+ Predicts masks given an image and prompt embeddings, using a
+ transformer architecture.
+
+ Arguments:
+ transformer_dim (int): the channel dimension of the transformer
+ transformer (nn.Module): the transformer used to predict masks
+ num_multimask_outputs (int): the number of masks to predict
+ when disambiguating masks
+ activation (nn.Module): the type of activation to use when
+ upscaling masks
+ iou_head_depth (int): the depth of the MLP used to predict
+ mask quality
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
+ used to predict mask quality
+ """
+ super().__init__()
+ self.transformer_dim = transformer_dim
+ self.transformer = transformer
+
+ self.num_multimask_outputs = num_multimask_outputs
+
+ self.iou_token = nn.Embedding(1, transformer_dim)
+ self.num_mask_tokens = num_multimask_outputs + 1
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
+
+ self.output_upscaling = nn.Sequential(
+ nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
+ LayerNorm2d(transformer_dim // 4),
+ activation(),
+ nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
+ activation(),
+ )
+ self.output_hypernetworks_mlps = nn.ModuleList(
+ [
+ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
+ for i in range(self.num_mask_tokens)
+ ]
+ )
+
+ self.iou_prediction_head = MLP(
+ transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
+ )
+
+ def forward(
+ self,
+ image_embeddings: torch.Tensor,
+ image_pe: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ multimask_output: bool,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Predict masks given image and prompt embeddings.
+
+ Arguments:
+ image_embeddings (torch.Tensor): the embeddings from the image encoder
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
+ dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
+ multimask_output (bool): Whether to return multiple masks or a single
+ mask.
+
+ Returns:
+ torch.Tensor: batched predicted masks
+ torch.Tensor: batched predictions of mask quality
+ """
+ masks, iou_pred = self.predict_masks(
+ image_embeddings=image_embeddings,
+ image_pe=image_pe,
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
+ dense_prompt_embeddings=dense_prompt_embeddings,
+ )
+
+ # Select the correct mask or masks for output
+ if multimask_output:
+ mask_slice = slice(1, None)
+ else:
+ mask_slice = slice(0, 1)
+ masks = masks[:, mask_slice, :, :]
+ iou_pred = iou_pred[:, mask_slice]
+
+ # Prepare output
+ return masks, iou_pred
+
+ def predict_masks(
+ self,
+ image_embeddings: torch.Tensor,
+ image_pe: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Predicts masks. See 'forward' for more details."""
+ # Concatenate output tokens
+ output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
+ output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+
+ # Expand per-image data in batch direction to be per-mask
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+ src = src + dense_prompt_embeddings
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+ b, c, h, w = src.shape
+
+ # Run the transformer
+ hs, src = self.transformer(src, pos_src, tokens)
+ iou_token_out = hs[:, 0, :]
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
+
+ # Upscale mask embeddings and predict masks using the mask tokens
+ src = src.transpose(1, 2).view(b, c, h, w)
+ upscaled_embedding = self.output_upscaling(src)
+ hyper_in_list: List[torch.Tensor] = []
+ for i in range(self.num_mask_tokens):
+ hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
+ hyper_in = torch.stack(hyper_in_list, dim=1)
+ b, c, h, w = upscaled_embedding.shape
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
+
+ # Generate mask quality predictions
+ iou_pred = self.iou_prediction_head(iou_token_out)
+
+ return masks, iou_pred
+
+
+# Lightly adapted from
+# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
+class MLP(nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int,
+ output_dim: int,
+ num_layers: int,
+ sigmoid_output: bool = False,
+ ) -> None:
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+ )
+ self.sigmoid_output = sigmoid_output
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ if self.sigmoid_output:
+ x = F.sigmoid(x)
+ return x
diff --git a/src/custom_controlnet_aux/sam/modeling/prompt_encoder.py b/src/custom_controlnet_aux/sam/modeling/prompt_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3143f4f8e02ddd7ca8587b40ff5d47c3a6b7ef3
--- /dev/null
+++ b/src/custom_controlnet_aux/sam/modeling/prompt_encoder.py
@@ -0,0 +1,214 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+from torch import nn
+
+from typing import Any, Optional, Tuple, Type
+
+from .common import LayerNorm2d
+
+
+class PromptEncoder(nn.Module):
+ def __init__(
+ self,
+ embed_dim: int,
+ image_embedding_size: Tuple[int, int],
+ input_image_size: Tuple[int, int],
+ mask_in_chans: int,
+ activation: Type[nn.Module] = nn.GELU,
+ ) -> None:
+ """
+ Encodes prompts for input to SAM's mask decoder.
+
+ Arguments:
+ embed_dim (int): The prompts' embedding dimension
+ image_embedding_size (tuple(int, int)): The spatial size of the
+ image embedding, as (H, W).
+ input_image_size (int): The padded size of the image as input
+ to the image encoder, as (H, W).
+ mask_in_chans (int): The number of hidden channels used for
+ encoding input masks.
+ activation (nn.Module): The activation to use when encoding
+ input masks.
+ """
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.input_image_size = input_image_size
+ self.image_embedding_size = image_embedding_size
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
+
+ self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
+ point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
+ self.point_embeddings = nn.ModuleList(point_embeddings)
+ self.not_a_point_embed = nn.Embedding(1, embed_dim)
+
+ self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
+ self.mask_downscaling = nn.Sequential(
+ nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
+ LayerNorm2d(mask_in_chans // 4),
+ activation(),
+ nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
+ LayerNorm2d(mask_in_chans),
+ activation(),
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
+ )
+ self.no_mask_embed = nn.Embedding(1, embed_dim)
+
+ def get_dense_pe(self) -> torch.Tensor:
+ """
+ Returns the positional encoding used to encode point prompts,
+ applied to a dense set of points the shape of the image encoding.
+
+ Returns:
+ torch.Tensor: Positional encoding with shape
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
+ """
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
+
+ def _embed_points(
+ self,
+ points: torch.Tensor,
+ labels: torch.Tensor,
+ pad: bool,
+ ) -> torch.Tensor:
+ """Embeds point prompts."""
+ points = points + 0.5 # Shift to center of pixel
+ if pad:
+ padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
+ padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
+ points = torch.cat([points, padding_point], dim=1)
+ labels = torch.cat([labels, padding_label], dim=1)
+ point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
+ point_embedding[labels == -1] = 0.0
+ point_embedding[labels == -1] += self.not_a_point_embed.weight
+ point_embedding[labels == 0] += self.point_embeddings[0].weight
+ point_embedding[labels == 1] += self.point_embeddings[1].weight
+ return point_embedding
+
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
+ """Embeds box prompts."""
+ boxes = boxes + 0.5 # Shift to center of pixel
+ coords = boxes.reshape(-1, 2, 2)
+ corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
+ corner_embedding[:, 0, :] += self.point_embeddings[2].weight
+ corner_embedding[:, 1, :] += self.point_embeddings[3].weight
+ return corner_embedding
+
+ def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
+ """Embeds mask inputs."""
+ mask_embedding = self.mask_downscaling(masks)
+ return mask_embedding
+
+ def _get_batch_size(
+ self,
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+ boxes: Optional[torch.Tensor],
+ masks: Optional[torch.Tensor],
+ ) -> int:
+ """
+ Gets the batch size of the output given the batch size of the input prompts.
+ """
+ if points is not None:
+ return points[0].shape[0]
+ elif boxes is not None:
+ return boxes.shape[0]
+ elif masks is not None:
+ return masks.shape[0]
+ else:
+ return 1
+
+ def _get_device(self) -> torch.device:
+ return self.point_embeddings[0].weight.device
+
+ def forward(
+ self,
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+ boxes: Optional[torch.Tensor],
+ masks: Optional[torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Embeds different types of prompts, returning both sparse and dense
+ embeddings.
+
+ Arguments:
+ points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
+ and labels to embed.
+ boxes (torch.Tensor or none): boxes to embed
+ masks (torch.Tensor or none): masks to embed
+
+ Returns:
+ torch.Tensor: sparse embeddings for the points and boxes, with shape
+ BxNx(embed_dim), where N is determined by the number of input points
+ and boxes.
+ torch.Tensor: dense embeddings for the masks, in the shape
+ Bx(embed_dim)x(embed_H)x(embed_W)
+ """
+ bs = self._get_batch_size(points, boxes, masks)
+ sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
+ if points is not None:
+ coords, labels = points
+ point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
+ sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
+ if boxes is not None:
+ box_embeddings = self._embed_boxes(boxes)
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
+
+ if masks is not None:
+ dense_embeddings = self._embed_masks(masks)
+ else:
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+ bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+ )
+
+ return sparse_embeddings, dense_embeddings
+
+
+class PositionEmbeddingRandom(nn.Module):
+ """
+ Positional encoding using random spatial frequencies.
+ """
+
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+ super().__init__()
+ if scale is None or scale <= 0.0:
+ scale = 1.0
+ self.register_buffer(
+ "positional_encoding_gaussian_matrix",
+ scale * torch.randn((2, num_pos_feats)),
+ )
+
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+ """Positionally encode points that are normalized to [0,1]."""
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+ coords = 2 * coords - 1
+ coords = coords @ self.positional_encoding_gaussian_matrix
+ coords = 2 * np.pi * coords
+ # outputs d_1 x ... x d_n x C shape
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
+ """Generate positional encoding for a grid of the specified size."""
+ h, w = size
+ device: Any = self.positional_encoding_gaussian_matrix.device
+ grid = torch.ones((h, w), device=device, dtype=torch.float32)
+ y_embed = grid.cumsum(dim=0) - 0.5
+ x_embed = grid.cumsum(dim=1) - 0.5
+ y_embed = y_embed / h
+ x_embed = x_embed / w
+
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
+ return pe.permute(2, 0, 1) # C x H x W
+
+ def forward_with_coords(
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
+ ) -> torch.Tensor:
+ """Positionally encode points that are not normalized to [0,1]."""
+ coords = coords_input.clone()
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
diff --git a/src/custom_controlnet_aux/sam/modeling/sam.py b/src/custom_controlnet_aux/sam/modeling/sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..7621a038122987acba55d0fc90a865968dd3fce1
--- /dev/null
+++ b/src/custom_controlnet_aux/sam/modeling/sam.py
@@ -0,0 +1,172 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from typing import Any, Dict, List, Tuple
+
+from .image_encoder import ImageEncoderViT
+from .mask_decoder import MaskDecoder
+from .prompt_encoder import PromptEncoder
+
+
+class Sam(nn.Module):
+ mask_threshold: float = 0.0
+ image_format: str = "RGB"
+
+ def __init__(
+ self,
+ image_encoder: ImageEncoderViT,
+ prompt_encoder: PromptEncoder,
+ mask_decoder: MaskDecoder,
+ pixel_mean: List[float] = [123.675, 116.28, 103.53],
+ pixel_std: List[float] = [58.395, 57.12, 57.375],
+ ) -> None:
+ """
+ SAM predicts object masks from an image and input prompts.
+
+ Arguments:
+ image_encoder (ImageEncoderViT): The backbone used to encode the
+ image into image embeddings that allow for efficient mask prediction.
+ prompt_encoder (PromptEncoder): Encodes various types of input prompts.
+ mask_decoder (MaskDecoder): Predicts masks from the image embeddings
+ and encoded prompts.
+ pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
+ pixel_std (list(float)): Std values for normalizing pixels in the input image.
+ """
+ super().__init__()
+ self.image_encoder = image_encoder
+ self.prompt_encoder = prompt_encoder
+ self.mask_decoder = mask_decoder
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
+
+ @property
+ def device(self) -> Any:
+ return self.pixel_mean.device
+
+ def forward(
+ self,
+ batched_input: List[Dict[str, Any]],
+ multimask_output: bool,
+ ) -> List[Dict[str, torch.Tensor]]:
+ """
+ Predicts masks end-to-end from provided images and prompts.
+ If prompts are not known in advance, using SamPredictor is
+ recommended over calling the model directly.
+
+ Arguments:
+ batched_input (list(dict)): A list over input images, each a
+ dictionary with the following keys. A prompt key can be
+ excluded if not known in advance.
+ 'image': The image as a torch tensor in 3xHxW format,
+ already transformed for input to the model.
+ 'original_size': (tuple(int, int)) The original size of
+ the image before transformation, as (H, W).
+ 'point_coords': (torch.Tensor) Batched point prompts for
+ this image, with shape BxNx2. Already transformed to the
+ input frame of the model.
+ 'point_labels': (torch.Tensor) Batched labels for point prompts,
+ with shape BxN.
+ 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
+ Already transformed to the input frame of the model.
+ 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
+ in the form Bx1xHxW.
+ multimask_output (bool): Whether the model should predict multiple
+ disambiguating masks, or return a single mask.
+
+ Returns:
+ (list(dict)): A list over input images, where each element is
+ as dictionary with the following keys.
+ 'masks': (torch.Tensor) Batched binary mask predictions,
+ with shape BxCxHxW, where B is the number of input prompts,
+ C is determined by multimask_output, and (H, W) is the
+ original size of the image.
+ 'iou_predictions': (torch.Tensor) The model's predictions
+ of mask quality, in shape BxC.
+ 'low_res_logits': (torch.Tensor) Low resolution logits with
+ shape BxCxHxW, where H=W=256. Can be passed as mask input
+ to subsequent iterations of prediction.
+ """
+ input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
+ image_embeddings = self.image_encoder(input_images)
+
+ outputs = []
+ for image_record, curr_embedding in zip(batched_input, image_embeddings):
+ if "point_coords" in image_record:
+ points = (image_record["point_coords"], image_record["point_labels"])
+ else:
+ points = None
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
+ points=points,
+ boxes=image_record.get("boxes", None),
+ masks=image_record.get("mask_inputs", None),
+ )
+ low_res_masks, iou_predictions = self.mask_decoder(
+ image_embeddings=curr_embedding.unsqueeze(0),
+ image_pe=self.prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ )
+ masks = self.postprocess_masks(
+ low_res_masks,
+ input_size=image_record["image"].shape[-2:],
+ original_size=image_record["original_size"],
+ )
+ outputs.append(
+ {
+ "masks": masks,
+ "iou_predictions": iou_predictions,
+ "low_res_logits": low_res_masks,
+ }
+ )
+ return outputs
+
+ def postprocess_masks(
+ self,
+ masks: torch.Tensor,
+ input_size: Tuple[int, ...],
+ original_size: Tuple[int, ...],
+ ) -> torch.Tensor:
+ """
+ Remove padding and upscale masks to the original image size.
+
+ Arguments:
+ masks (torch.Tensor): Batched masks from the mask_decoder,
+ in BxCxHxW format.
+ input_size (tuple(int, int)): The size of the image input to the
+ model, in (H, W) format. Used to remove padding.
+ original_size (tuple(int, int)): The original size of the image
+ before resizing for input to the model, in (H, W) format.
+
+ Returns:
+ (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
+ is given by original_size.
+ """
+ masks = F.interpolate(
+ masks,
+ (self.image_encoder.img_size, self.image_encoder.img_size),
+ mode="bilinear",
+ align_corners=False,
+ )
+ masks = masks[..., : input_size[0], : input_size[1]]
+ masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
+ return masks
+
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
+ """Normalize pixel values and pad to a square input."""
+ # Normalize colors
+ x = (x - self.pixel_mean) / self.pixel_std
+
+ # Pad
+ h, w = x.shape[-2:]
+ padh = self.image_encoder.img_size - h
+ padw = self.image_encoder.img_size - w
+ x = F.pad(x, (0, padw, 0, padh))
+ return x
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/sam/modeling/transformer.py b/src/custom_controlnet_aux/sam/modeling/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..28fafea52288603fea275f3a100790471825c34a
--- /dev/null
+++ b/src/custom_controlnet_aux/sam/modeling/transformer.py
@@ -0,0 +1,240 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import Tensor, nn
+
+import math
+from typing import Tuple, Type
+
+from .common import MLPBlock
+
+
+class TwoWayTransformer(nn.Module):
+ def __init__(
+ self,
+ depth: int,
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int,
+ activation: Type[nn.Module] = nn.ReLU,
+ attention_downsample_rate: int = 2,
+ ) -> None:
+ """
+ A transformer decoder that attends to an input image using
+ queries whose positional embedding is supplied.
+
+ Args:
+ depth (int): number of layers in the transformer
+ embedding_dim (int): the channel dimension for the input embeddings
+ num_heads (int): the number of heads for multihead attention. Must
+ divide embedding_dim
+ mlp_dim (int): the channel dimension internal to the MLP block
+ activation (nn.Module): the activation to use in the MLP block
+ """
+ super().__init__()
+ self.depth = depth
+ self.embedding_dim = embedding_dim
+ self.num_heads = num_heads
+ self.mlp_dim = mlp_dim
+ self.layers = nn.ModuleList()
+
+ for i in range(depth):
+ self.layers.append(
+ TwoWayAttentionBlock(
+ embedding_dim=embedding_dim,
+ num_heads=num_heads,
+ mlp_dim=mlp_dim,
+ activation=activation,
+ attention_downsample_rate=attention_downsample_rate,
+ skip_first_layer_pe=(i == 0),
+ )
+ )
+
+ self.final_attn_token_to_image = Attention(
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+ )
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
+
+ def forward(
+ self,
+ image_embedding: Tensor,
+ image_pe: Tensor,
+ point_embedding: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ image_embedding (torch.Tensor): image to attend to. Should be shape
+ B x embedding_dim x h x w for any h and w.
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
+ have the same shape as image_embedding.
+ point_embedding (torch.Tensor): the embedding to add to the query points.
+ Must have shape B x N_points x embedding_dim for any N_points.
+
+ Returns:
+ torch.Tensor: the processed point_embedding
+ torch.Tensor: the processed image_embedding
+ """
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
+ bs, c, h, w = image_embedding.shape
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
+
+ # Prepare queries
+ queries = point_embedding
+ keys = image_embedding
+
+ # Apply transformer blocks and final layernorm
+ for layer in self.layers:
+ queries, keys = layer(
+ queries=queries,
+ keys=keys,
+ query_pe=point_embedding,
+ key_pe=image_pe,
+ )
+
+ # Apply the final attention layer from the points to the image
+ q = queries + point_embedding
+ k = keys + image_pe
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
+ queries = queries + attn_out
+ queries = self.norm_final_attn(queries)
+
+ return queries, keys
+
+
+class TwoWayAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int = 2048,
+ activation: Type[nn.Module] = nn.ReLU,
+ attention_downsample_rate: int = 2,
+ skip_first_layer_pe: bool = False,
+ ) -> None:
+ """
+ A transformer block with four layers: (1) self-attention of sparse
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
+ inputs.
+
+ Arguments:
+ embedding_dim (int): the channel dimension of the embeddings
+ num_heads (int): the number of heads in the attention layers
+ mlp_dim (int): the hidden dimension of the mlp block
+ activation (nn.Module): the activation of the mlp block
+ skip_first_layer_pe (bool): skip the PE on the first layer
+ """
+ super().__init__()
+ self.self_attn = Attention(embedding_dim, num_heads)
+ self.norm1 = nn.LayerNorm(embedding_dim)
+
+ self.cross_attn_token_to_image = Attention(
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+ )
+ self.norm2 = nn.LayerNorm(embedding_dim)
+
+ self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
+ self.norm3 = nn.LayerNorm(embedding_dim)
+
+ self.norm4 = nn.LayerNorm(embedding_dim)
+ self.cross_attn_image_to_token = Attention(
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+ )
+
+ self.skip_first_layer_pe = skip_first_layer_pe
+
+ def forward(
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
+ ) -> Tuple[Tensor, Tensor]:
+ # Self attention block
+ if self.skip_first_layer_pe:
+ queries = self.self_attn(q=queries, k=queries, v=queries)
+ else:
+ q = queries + query_pe
+ attn_out = self.self_attn(q=q, k=q, v=queries)
+ queries = queries + attn_out
+ queries = self.norm1(queries)
+
+ # Cross attention block, tokens attending to image embedding
+ q = queries + query_pe
+ k = keys + key_pe
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
+ queries = queries + attn_out
+ queries = self.norm2(queries)
+
+ # MLP block
+ mlp_out = self.mlp(queries)
+ queries = queries + mlp_out
+ queries = self.norm3(queries)
+
+ # Cross attention block, image embedding attending to tokens
+ q = queries + query_pe
+ k = keys + key_pe
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
+ keys = keys + attn_out
+ keys = self.norm4(keys)
+
+ return queries, keys
+
+
+class Attention(nn.Module):
+ """
+ An attention layer that allows for downscaling the size of the embedding
+ after projection to queries, keys, and values.
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int,
+ downsample_rate: int = 1,
+ ) -> None:
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.internal_dim = embedding_dim // downsample_rate
+ self.num_heads = num_heads
+ assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
+
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
+ b, n, c = x.shape
+ x = x.reshape(b, n, num_heads, c // num_heads)
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
+
+ def _recombine_heads(self, x: Tensor) -> Tensor:
+ b, n_heads, n_tokens, c_per_head = x.shape
+ x = x.transpose(1, 2)
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
+
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+ # Input projections
+ q = self.q_proj(q)
+ k = self.k_proj(k)
+ v = self.v_proj(v)
+
+ # Separate into heads
+ q = self._separate_heads(q, self.num_heads)
+ k = self._separate_heads(k, self.num_heads)
+ v = self._separate_heads(v, self.num_heads)
+
+ # Attention
+ _, _, _, c_per_head = q.shape
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
+ attn = attn / math.sqrt(c_per_head)
+ attn = torch.softmax(attn, dim=-1)
+
+ # Get output
+ out = attn @ v
+ out = self._recombine_heads(out)
+ out = self.out_proj(out)
+
+ return out
diff --git a/src/custom_controlnet_aux/sam/predictor.py b/src/custom_controlnet_aux/sam/predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3820fb7de8647e5d6adf229debc498b33caad62
--- /dev/null
+++ b/src/custom_controlnet_aux/sam/predictor.py
@@ -0,0 +1,269 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+
+from .modeling import Sam
+
+from typing import Optional, Tuple
+
+from .utils.transforms import ResizeLongestSide
+
+
+class SamPredictor:
+ def __init__(
+ self,
+ sam_model: Sam,
+ ) -> None:
+ """
+ Uses SAM to calculate the image embedding for an image, and then
+ allow repeated, efficient mask prediction given prompts.
+
+ Arguments:
+ sam_model (Sam): The model to use for mask prediction.
+ """
+ super().__init__()
+ self.model = sam_model
+ self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
+ self.reset_image()
+
+ def set_image(
+ self,
+ image: np.ndarray,
+ image_format: str = "RGB",
+ ) -> None:
+ """
+ Calculates the image embeddings for the provided image, allowing
+ masks to be predicted with the 'predict' method.
+
+ Arguments:
+ image (np.ndarray): The image for calculating masks. Expects an
+ image in HWC uint8 format, with pixel values in [0, 255].
+ image_format (str): The color format of the image, in ['RGB', 'BGR'].
+ """
+ assert image_format in [
+ "RGB",
+ "BGR",
+ ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
+ if image_format != self.model.image_format:
+ image = image[..., ::-1]
+
+ # Transform the image to the form expected by the model
+ input_image = self.transform.apply_image(image)
+ input_image_torch = torch.as_tensor(input_image, device=self.device)
+ input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
+
+ self.set_torch_image(input_image_torch, image.shape[:2])
+
+ @torch.no_grad()
+ def set_torch_image(
+ self,
+ transformed_image: torch.Tensor,
+ original_image_size: Tuple[int, ...],
+ ) -> None:
+ """
+ Calculates the image embeddings for the provided image, allowing
+ masks to be predicted with the 'predict' method. Expects the input
+ image to be already transformed to the format expected by the model.
+
+ Arguments:
+ transformed_image (torch.Tensor): The input image, with shape
+ 1x3xHxW, which has been transformed with ResizeLongestSide.
+ original_image_size (tuple(int, int)): The size of the image
+ before transformation, in (H, W) format.
+ """
+ assert (
+ len(transformed_image.shape) == 4
+ and transformed_image.shape[1] == 3
+ and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
+ ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
+ self.reset_image()
+
+ self.original_size = original_image_size
+ self.input_size = tuple(transformed_image.shape[-2:])
+ input_image = self.model.preprocess(transformed_image)
+ self.features = self.model.image_encoder(input_image)
+ self.is_image_set = True
+
+ def predict(
+ self,
+ point_coords: Optional[np.ndarray] = None,
+ point_labels: Optional[np.ndarray] = None,
+ box: Optional[np.ndarray] = None,
+ mask_input: Optional[np.ndarray] = None,
+ multimask_output: bool = True,
+ return_logits: bool = False,
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """
+ Predict masks for the given input prompts, using the currently set image.
+
+ Arguments:
+ point_coords (np.ndarray or None): A Nx2 array of point prompts to the
+ model. Each point is in (X,Y) in pixels.
+ point_labels (np.ndarray or None): A length N array of labels for the
+ point prompts. 1 indicates a foreground point and 0 indicates a
+ background point.
+ box (np.ndarray or None): A length 4 array given a box prompt to the
+ model, in XYXY format.
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
+ coming from a previous prediction iteration. Has form 1xHxW, where
+ for SAM, H=W=256.
+ multimask_output (bool): If true, the model will return three masks.
+ For ambiguous input prompts (such as a single click), this will often
+ produce better masks than a single prediction. If only a single
+ mask is needed, the model's predicted quality score can be used
+ to select the best mask. For non-ambiguous prompts, such as multiple
+ input prompts, multimask_output=False can give better results.
+ return_logits (bool): If true, returns un-thresholded masks logits
+ instead of a binary mask.
+
+ Returns:
+ (np.ndarray): The output masks in CxHxW format, where C is the
+ number of masks, and (H, W) is the original image size.
+ (np.ndarray): An array of length C containing the model's
+ predictions for the quality of each mask.
+ (np.ndarray): An array of shape CxHxW, where C is the number
+ of masks and H=W=256. These low resolution logits can be passed to
+ a subsequent iteration as mask input.
+ """
+ if not self.is_image_set:
+ raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
+
+ # Transform input prompts
+ coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
+ if point_coords is not None:
+ assert (
+ point_labels is not None
+ ), "point_labels must be supplied if point_coords is supplied."
+ point_coords = self.transform.apply_coords(point_coords, self.original_size)
+ coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
+ labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
+ coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
+ if box is not None:
+ box = self.transform.apply_boxes(box, self.original_size)
+ box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
+ box_torch = box_torch[None, :]
+ if mask_input is not None:
+ mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
+ mask_input_torch = mask_input_torch[None, :, :, :]
+
+ masks, iou_predictions, low_res_masks = self.predict_torch(
+ coords_torch,
+ labels_torch,
+ box_torch,
+ mask_input_torch,
+ multimask_output,
+ return_logits=return_logits,
+ )
+
+ masks_np = masks[0].detach().cpu().numpy()
+ iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
+ low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
+ return masks_np, iou_predictions_np, low_res_masks_np
+
+ @torch.no_grad()
+ def predict_torch(
+ self,
+ point_coords: Optional[torch.Tensor],
+ point_labels: Optional[torch.Tensor],
+ boxes: Optional[torch.Tensor] = None,
+ mask_input: Optional[torch.Tensor] = None,
+ multimask_output: bool = True,
+ return_logits: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Predict masks for the given input prompts, using the currently set image.
+ Input prompts are batched torch tensors and are expected to already be
+ transformed to the input frame using ResizeLongestSide.
+
+ Arguments:
+ point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
+ model. Each point is in (X,Y) in pixels.
+ point_labels (torch.Tensor or None): A BxN array of labels for the
+ point prompts. 1 indicates a foreground point and 0 indicates a
+ background point.
+ boxes (np.ndarray or None): A Bx4 array given a box prompt to the
+ model, in XYXY format.
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
+ coming from a previous prediction iteration. Has form Bx1xHxW, where
+ for SAM, H=W=256. Masks returned by a previous iteration of the
+ predict method do not need further transformation.
+ multimask_output (bool): If true, the model will return three masks.
+ For ambiguous input prompts (such as a single click), this will often
+ produce better masks than a single prediction. If only a single
+ mask is needed, the model's predicted quality score can be used
+ to select the best mask. For non-ambiguous prompts, such as multiple
+ input prompts, multimask_output=False can give better results.
+ return_logits (bool): If true, returns un-thresholded masks logits
+ instead of a binary mask.
+
+ Returns:
+ (torch.Tensor): The output masks in BxCxHxW format, where C is the
+ number of masks, and (H, W) is the original image size.
+ (torch.Tensor): An array of shape BxC containing the model's
+ predictions for the quality of each mask.
+ (torch.Tensor): An array of shape BxCxHxW, where C is the number
+ of masks and H=W=256. These low res logits can be passed to
+ a subsequent iteration as mask input.
+ """
+ if not self.is_image_set:
+ raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
+
+ if point_coords is not None:
+ points = (point_coords, point_labels)
+ else:
+ points = None
+
+ # Embed prompts
+ sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
+ points=points,
+ boxes=boxes,
+ masks=mask_input,
+ )
+
+ # Predict masks
+ low_res_masks, iou_predictions = self.model.mask_decoder(
+ image_embeddings=self.features,
+ image_pe=self.model.prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ )
+
+ # Upscale the masks to the original image resolution
+ masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
+
+ if not return_logits:
+ masks = masks > self.model.mask_threshold
+
+ return masks, iou_predictions, low_res_masks
+
+ def get_image_embedding(self) -> torch.Tensor:
+ """
+ Returns the image embeddings for the currently set image, with
+ shape 1xCxHxW, where C is the embedding dimension and (H,W) are
+ the embedding spatial dimension of SAM (typically C=256, H=W=64).
+ """
+ if not self.is_image_set:
+ raise RuntimeError(
+ "An image must be set with .set_image(...) to generate an embedding."
+ )
+ assert self.features is not None, "Features must exist if an image has been set."
+ return self.features
+
+ @property
+ def device(self) -> torch.device:
+ return self.model.device
+
+ def reset_image(self) -> None:
+ """Resets the currently set image."""
+ self.is_image_set = False
+ self.features = None
+ self.orig_h = None
+ self.orig_w = None
+ self.input_h = None
+ self.input_w = None
diff --git a/src/custom_controlnet_aux/sam/sam.py b/src/custom_controlnet_aux/sam/sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7d0b5fcdf270593799852044db6f3896e828b4c
--- /dev/null
+++ b/src/custom_controlnet_aux/sam/sam.py
@@ -0,0 +1,171 @@
+"""
+SAM implementation using HuggingFace transformers for PyTorch 2.7 compatibility.
+"""
+import numpy as np
+import torch
+from PIL import Image
+from typing import Union
+
+# Import utilities
+from ..util import HWC3, common_input_validate, resize_image_with_pad
+
+
+class SamDetector:
+
+ def __init__(self, model_name="facebook/sam-vit-base"):
+ from transformers import SamModel, SamProcessor
+
+ self.model_name = model_name
+ self.processor = SamProcessor.from_pretrained(model_name)
+ self.model = SamModel.from_pretrained(model_name)
+ self.device = "cpu"
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path=None, model_type="vit_t", filename="mobile_sam.pt", subfolder=None):
+ model_mapping = {
+ "vit_t": "facebook/sam-vit-base",
+ "vit_b": "facebook/sam-vit-base",
+ "vit_l": "facebook/sam-vit-large",
+ "vit_h": "facebook/sam-vit-huge"
+ }
+ if filename and isinstance(filename, str):
+ if "mobile_sam" in filename.lower():
+ model_name = "facebook/sam-vit-base"
+ elif "sam_vit_h" in filename.lower():
+ model_name = "facebook/sam-vit-huge"
+ elif "sam_vit_l" in filename.lower():
+ model_name = "facebook/sam-vit-large"
+ elif "sam_vit_b" in filename.lower():
+ model_name = "facebook/sam-vit-base"
+ else:
+ model_name = model_mapping.get(model_type, "facebook/sam-vit-base")
+ else:
+ model_name = model_mapping.get(model_type, "facebook/sam-vit-base")
+
+ return cls(model_name)
+
+ def to(self, device):
+ self.model = self.model.to(device)
+ self.device = device
+ return self
+
+ def generate_automatic_masks(self, input_image):
+ if isinstance(input_image, np.ndarray):
+ pil_image = Image.fromarray(input_image)
+ else:
+ pil_image = input_image
+
+ height, width = pil_image.size[1], pil_image.size[0]
+
+ points_per_side = max(8, min(24, width // 64, height // 64))
+
+ grid_points = []
+ for i in range(points_per_side):
+ for j in range(points_per_side):
+ x = int((j + 0.5) * width / points_per_side)
+ y = int((i + 0.5) * height / points_per_side)
+ x_offset = int((np.random.random() - 0.5) * (width / points_per_side * 0.3))
+ y_offset = int((np.random.random() - 0.5) * (height / points_per_side * 0.3))
+ x = max(5, min(width - 5, x + x_offset))
+ y = max(5, min(height - 5, y + y_offset))
+ grid_points.append([x, y])
+
+ batch_size = 16
+ all_masks = []
+
+ for i in range(0, len(grid_points), batch_size):
+ batch_points = grid_points[i:i + batch_size]
+ input_points = [batch_points]
+
+ inputs = self.processor(
+ images=pil_image,
+ input_points=input_points,
+ return_tensors="pt"
+ ).to(self.device)
+
+ with torch.no_grad():
+ outputs = self.model(**inputs)
+
+ masks = self.processor.post_process_masks(
+ outputs.pred_masks,
+ inputs["original_sizes"],
+ inputs["reshaped_input_sizes"]
+ )[0]
+
+ masks_np = masks.cpu().numpy()
+
+ for j, mask in enumerate(masks_np):
+ mask_2d = mask[0] if len(mask.shape) > 2 else mask
+ area = int(mask_2d.sum())
+
+ if area > 100:
+ cleaned_mask = self._postprocess_mask(mask_2d)
+ cleaned_area = int(cleaned_mask.sum())
+
+ mask_dict = {
+ 'segmentation': cleaned_mask,
+ 'area': cleaned_area,
+ 'stability_score': 0.88,
+ 'point_coords': batch_points[j % len(batch_points)]
+ }
+ all_masks.append(mask_dict)
+
+ return all_masks
+
+ def _postprocess_mask(self, mask, min_region_area=100):
+ from scipy import ndimage
+ from skimage import morphology
+
+ binary_mask = mask.astype(bool)
+
+ filled_mask = ndimage.binary_fill_holes(binary_mask)
+
+ if filled_mask.any():
+ kernel_close = morphology.disk(5)
+ kernel_open = morphology.disk(3)
+
+ smoothed_mask = morphology.binary_closing(filled_mask, kernel_close)
+ smoothed_mask = morphology.binary_opening(smoothed_mask, kernel_open)
+ smoothed_mask = morphology.binary_closing(smoothed_mask, kernel_close)
+ else:
+ smoothed_mask = filled_mask
+
+ return smoothed_mask.astype(mask.dtype)
+
+ def show_anns(self, anns):
+ if len(anns) == 0:
+ return None
+ sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
+
+ h, w = anns[0]['segmentation'].shape
+
+ final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
+
+ for ann in sorted_anns:
+ m = ann['segmentation']
+
+ img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
+ for i in range(3):
+ img[:,:,i] = np.random.randint(255, dtype=np.uint8)
+
+ final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m*255)))
+
+ return np.array(final_img, dtype=np.uint8)
+
+ def __call__(self, input_image: Union[np.ndarray, Image.Image]=None, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", **kwargs) -> Image.Image:
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+
+ masks = self.generate_automatic_masks(input_image)
+
+ map = self.show_anns(masks)
+
+ if map is None:
+ map = np.zeros((input_image.shape[0], input_image.shape[1], 3), dtype=np.uint8)
+
+ detected_map = HWC3(remove_pad(map))
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/sam/utils/__init__.py b/src/custom_controlnet_aux/sam/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/src/custom_controlnet_aux/sam/utils/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/src/custom_controlnet_aux/sam/utils/amg.py b/src/custom_controlnet_aux/sam/utils/amg.py
new file mode 100644
index 0000000000000000000000000000000000000000..be064071ef399fea96c673ad173689656c23534a
--- /dev/null
+++ b/src/custom_controlnet_aux/sam/utils/amg.py
@@ -0,0 +1,346 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+
+import math
+from copy import deepcopy
+from itertools import product
+from typing import Any, Dict, Generator, ItemsView, List, Tuple
+
+
+class MaskData:
+ """
+ A structure for storing masks and their related data in batched format.
+ Implements basic filtering and concatenation.
+ """
+
+ def __init__(self, **kwargs) -> None:
+ for v in kwargs.values():
+ assert isinstance(
+ v, (list, np.ndarray, torch.Tensor)
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
+ self._stats = dict(**kwargs)
+
+ def __setitem__(self, key: str, item: Any) -> None:
+ assert isinstance(
+ item, (list, np.ndarray, torch.Tensor)
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
+ self._stats[key] = item
+
+ def __delitem__(self, key: str) -> None:
+ del self._stats[key]
+
+ def __getitem__(self, key: str) -> Any:
+ return self._stats[key]
+
+ def items(self) -> ItemsView[str, Any]:
+ return self._stats.items()
+
+ def filter(self, keep: torch.Tensor) -> None:
+ for k, v in self._stats.items():
+ if v is None:
+ self._stats[k] = None
+ elif isinstance(v, torch.Tensor):
+ self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
+ elif isinstance(v, np.ndarray):
+ self._stats[k] = v[keep.detach().cpu().numpy()]
+ elif isinstance(v, list) and keep.dtype == torch.bool:
+ self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
+ elif isinstance(v, list):
+ self._stats[k] = [v[i] for i in keep]
+ else:
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+ def cat(self, new_stats: "MaskData") -> None:
+ for k, v in new_stats.items():
+ if k not in self._stats or self._stats[k] is None:
+ self._stats[k] = deepcopy(v)
+ elif isinstance(v, torch.Tensor):
+ self._stats[k] = torch.cat([self._stats[k], v], dim=0)
+ elif isinstance(v, np.ndarray):
+ self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
+ elif isinstance(v, list):
+ self._stats[k] = self._stats[k] + deepcopy(v)
+ else:
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+ def to_numpy(self) -> None:
+ for k, v in self._stats.items():
+ if isinstance(v, torch.Tensor):
+ self._stats[k] = v.detach().cpu().numpy()
+
+
+def is_box_near_crop_edge(
+ boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
+) -> torch.Tensor:
+ """Filter masks at the edge of a crop, but not at the edge of the original image."""
+ crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
+ orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
+ boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
+ near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
+ near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
+ near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
+ return torch.any(near_crop_edge, dim=1)
+
+
+def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
+ box_xywh = deepcopy(box_xyxy)
+ box_xywh[2] = box_xywh[2] - box_xywh[0]
+ box_xywh[3] = box_xywh[3] - box_xywh[1]
+ return box_xywh
+
+
+def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
+ assert len(args) > 0 and all(
+ len(a) == len(args[0]) for a in args
+ ), "Batched iteration must have inputs of all the same size."
+ n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
+ for b in range(n_batches):
+ yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
+
+
+def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
+ """
+ Encodes masks to an uncompressed RLE, in the format expected by
+ pycoco tools.
+ """
+ # Put in fortran order and flatten h,w
+ b, h, w = tensor.shape
+ tensor = tensor.permute(0, 2, 1).flatten(1)
+
+ # Compute change indices
+ diff = tensor[:, 1:] ^ tensor[:, :-1]
+ change_indices = diff.nonzero()
+
+ # Encode run length
+ out = []
+ for i in range(b):
+ cur_idxs = change_indices[change_indices[:, 0] == i, 1]
+ cur_idxs = torch.cat(
+ [
+ torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
+ cur_idxs + 1,
+ torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
+ ]
+ )
+ btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
+ counts = [] if tensor[i, 0] == 0 else [0]
+ counts.extend(btw_idxs.detach().cpu().tolist())
+ out.append({"size": [h, w], "counts": counts})
+ return out
+
+
+def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
+ """Compute a binary mask from an uncompressed RLE."""
+ h, w = rle["size"]
+ mask = np.empty(h * w, dtype=bool)
+ idx = 0
+ parity = False
+ for count in rle["counts"]:
+ mask[idx : idx + count] = parity
+ idx += count
+ parity ^= True
+ mask = mask.reshape(w, h)
+ return mask.transpose() # Put in C order
+
+
+def area_from_rle(rle: Dict[str, Any]) -> int:
+ return sum(rle["counts"][1::2])
+
+
+def calculate_stability_score(
+ masks: torch.Tensor, mask_threshold: float, threshold_offset: float
+) -> torch.Tensor:
+ """
+ Computes the stability score for a batch of masks. The stability
+ score is the IoU between the binary masks obtained by thresholding
+ the predicted mask logits at high and low values.
+ """
+ # One mask is always contained inside the other.
+ # Save memory by preventing unnecessary cast to torch.int64
+ intersections = (
+ (masks > (mask_threshold + threshold_offset))
+ .sum(-1, dtype=torch.int16)
+ .sum(-1, dtype=torch.int32)
+ )
+ unions = (
+ (masks > (mask_threshold - threshold_offset))
+ .sum(-1, dtype=torch.int16)
+ .sum(-1, dtype=torch.int32)
+ )
+ return intersections / unions
+
+
+def build_point_grid(n_per_side: int) -> np.ndarray:
+ """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
+ offset = 1 / (2 * n_per_side)
+ points_one_side = np.linspace(offset, 1 - offset, n_per_side)
+ points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
+ points_y = np.tile(points_one_side[:, None], (1, n_per_side))
+ points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
+ return points
+
+
+def build_all_layer_point_grids(
+ n_per_side: int, n_layers: int, scale_per_layer: int
+) -> List[np.ndarray]:
+ """Generates point grids for all crop layers."""
+ points_by_layer = []
+ for i in range(n_layers + 1):
+ n_points = int(n_per_side / (scale_per_layer**i))
+ points_by_layer.append(build_point_grid(n_points))
+ return points_by_layer
+
+
+def generate_crop_boxes(
+ im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
+) -> Tuple[List[List[int]], List[int]]:
+ """
+ Generates a list of crop boxes of different sizes. Each layer
+ has (2**i)**2 boxes for the ith layer.
+ """
+ crop_boxes, layer_idxs = [], []
+ im_h, im_w = im_size
+ short_side = min(im_h, im_w)
+
+ # Original image
+ crop_boxes.append([0, 0, im_w, im_h])
+ layer_idxs.append(0)
+
+ def crop_len(orig_len, n_crops, overlap):
+ return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
+
+ for i_layer in range(n_layers):
+ n_crops_per_side = 2 ** (i_layer + 1)
+ overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
+
+ crop_w = crop_len(im_w, n_crops_per_side, overlap)
+ crop_h = crop_len(im_h, n_crops_per_side, overlap)
+
+ crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
+ crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
+
+ # Crops in XYWH format
+ for x0, y0 in product(crop_box_x0, crop_box_y0):
+ box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
+ crop_boxes.append(box)
+ layer_idxs.append(i_layer + 1)
+
+ return crop_boxes, layer_idxs
+
+
+def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+ x0, y0, _, _ = crop_box
+ offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
+ # Check if boxes has a channel dimension
+ if len(boxes.shape) == 3:
+ offset = offset.unsqueeze(1)
+ return boxes + offset
+
+
+def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+ x0, y0, _, _ = crop_box
+ offset = torch.tensor([[x0, y0]], device=points.device)
+ # Check if points has a channel dimension
+ if len(points.shape) == 3:
+ offset = offset.unsqueeze(1)
+ return points + offset
+
+
+def uncrop_masks(
+ masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
+) -> torch.Tensor:
+ x0, y0, x1, y1 = crop_box
+ if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
+ return masks
+ # Coordinate transform masks
+ pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
+ pad = (x0, pad_x - x0, y0, pad_y - y0)
+ return torch.nn.functional.pad(masks, pad, value=0)
+
+
+def remove_small_regions(
+ mask: np.ndarray, area_thresh: float, mode: str
+) -> Tuple[np.ndarray, bool]:
+ """
+ Removes small disconnected regions and holes in a mask. Returns the
+ mask and an indicator of if the mask has been modified.
+ """
+ import cv2 # type: ignore
+
+ assert mode in ["holes", "islands"]
+ correct_holes = mode == "holes"
+ working_mask = (correct_holes ^ mask).astype(np.uint8)
+ n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
+ sizes = stats[:, -1][1:] # Row 0 is background label
+ small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
+ if len(small_regions) == 0:
+ return mask, False
+ fill_labels = [0] + small_regions
+ if not correct_holes:
+ fill_labels = [i for i in range(n_labels) if i not in fill_labels]
+ # If every region is below threshold, keep largest
+ if len(fill_labels) == 0:
+ fill_labels = [int(np.argmax(sizes)) + 1]
+ mask = np.isin(regions, fill_labels)
+ return mask, True
+
+
+def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
+ from pycocotools import mask as mask_utils # type: ignore
+
+ h, w = uncompressed_rle["size"]
+ rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
+ rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
+ return rle
+
+
+def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
+ """
+ Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
+ an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
+ """
+ # torch.max below raises an error on empty inputs, just skip in this case
+ if torch.numel(masks) == 0:
+ return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
+
+ # Normalize shape to CxHxW
+ shape = masks.shape
+ h, w = shape[-2:]
+ if len(shape) > 2:
+ masks = masks.flatten(0, -3)
+ else:
+ masks = masks.unsqueeze(0)
+
+ # Get top and bottom edges
+ in_height, _ = torch.max(masks, dim=-1)
+ in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
+ bottom_edges, _ = torch.max(in_height_coords, dim=-1)
+ in_height_coords = in_height_coords + h * (~in_height)
+ top_edges, _ = torch.min(in_height_coords, dim=-1)
+
+ # Get left and right edges
+ in_width, _ = torch.max(masks, dim=-2)
+ in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
+ right_edges, _ = torch.max(in_width_coords, dim=-1)
+ in_width_coords = in_width_coords + w * (~in_width)
+ left_edges, _ = torch.min(in_width_coords, dim=-1)
+
+ # If the mask is empty the right edge will be to the left of the left edge.
+ # Replace these boxes with [0, 0, 0, 0]
+ empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+ out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
+ out = out * (~empty_filter).unsqueeze(-1)
+
+ # Return to original shape
+ if len(shape) > 2:
+ out = out.reshape(*shape[:-2], 4)
+ else:
+ out = out[0]
+
+ return out
diff --git a/src/custom_controlnet_aux/sam/utils/onnx.py b/src/custom_controlnet_aux/sam/utils/onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..3196bdf4b782e6eeb3da4ad66ef3c7b1741535fe
--- /dev/null
+++ b/src/custom_controlnet_aux/sam/utils/onnx.py
@@ -0,0 +1,144 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+from typing import Tuple
+
+from ..modeling import Sam
+from .amg import calculate_stability_score
+
+
+class SamOnnxModel(nn.Module):
+ """
+ This model should not be called directly, but is used in ONNX export.
+ It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
+ with some functions modified to enable model tracing. Also supports extra
+ options controlling what information. See the ONNX export script for details.
+ """
+
+ def __init__(
+ self,
+ model: Sam,
+ return_single_mask: bool,
+ use_stability_score: bool = False,
+ return_extra_metrics: bool = False,
+ ) -> None:
+ super().__init__()
+ self.mask_decoder = model.mask_decoder
+ self.model = model
+ self.img_size = model.image_encoder.img_size
+ self.return_single_mask = return_single_mask
+ self.use_stability_score = use_stability_score
+ self.stability_score_offset = 1.0
+ self.return_extra_metrics = return_extra_metrics
+
+ @staticmethod
+ def resize_longest_image_size(
+ input_image_size: torch.Tensor, longest_side: int
+ ) -> torch.Tensor:
+ input_image_size = input_image_size.to(torch.float32)
+ scale = longest_side / torch.max(input_image_size)
+ transformed_size = scale * input_image_size
+ transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
+ return transformed_size
+
+ def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
+ point_coords = point_coords + 0.5
+ point_coords = point_coords / self.img_size
+ point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
+ point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
+
+ point_embedding = point_embedding * (point_labels != -1)
+ point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
+ point_labels == -1
+ )
+
+ for i in range(self.model.prompt_encoder.num_point_embeddings):
+ point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
+ i
+ ].weight * (point_labels == i)
+
+ return point_embedding
+
+ def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
+ mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
+ mask_embedding = mask_embedding + (
+ 1 - has_mask_input
+ ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
+ return mask_embedding
+
+ def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
+ masks = F.interpolate(
+ masks,
+ size=(self.img_size, self.img_size),
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
+ masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore
+
+ orig_im_size = orig_im_size.to(torch.int64)
+ h, w = orig_im_size[0], orig_im_size[1]
+ masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
+ return masks
+
+ def select_masks(
+ self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Determine if we should return the multiclick mask or not from the number of points.
+ # The reweighting is used to avoid control flow.
+ score_reweight = torch.tensor(
+ [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
+ ).to(iou_preds.device)
+ score = iou_preds + (num_points - 2.5) * score_reweight
+ best_idx = torch.argmax(score, dim=1)
+ masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
+ iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
+
+ return masks, iou_preds
+
+ @torch.no_grad()
+ def forward(
+ self,
+ image_embeddings: torch.Tensor,
+ point_coords: torch.Tensor,
+ point_labels: torch.Tensor,
+ mask_input: torch.Tensor,
+ has_mask_input: torch.Tensor,
+ orig_im_size: torch.Tensor,
+ ):
+ sparse_embedding = self._embed_points(point_coords, point_labels)
+ dense_embedding = self._embed_masks(mask_input, has_mask_input)
+
+ masks, scores = self.model.mask_decoder.predict_masks(
+ image_embeddings=image_embeddings,
+ image_pe=self.model.prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embedding,
+ dense_prompt_embeddings=dense_embedding,
+ )
+
+ if self.use_stability_score:
+ scores = calculate_stability_score(
+ masks, self.model.mask_threshold, self.stability_score_offset
+ )
+
+ if self.return_single_mask:
+ masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
+
+ upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
+
+ if self.return_extra_metrics:
+ stability_scores = calculate_stability_score(
+ upscaled_masks, self.model.mask_threshold, self.stability_score_offset
+ )
+ areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
+ return upscaled_masks, scores, stability_scores, areas, masks
+
+ return upscaled_masks, scores, masks
diff --git a/src/custom_controlnet_aux/sam/utils/transforms.py b/src/custom_controlnet_aux/sam/utils/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..c08ba1e3db751f3a5483a003be38c69c2cf2df85
--- /dev/null
+++ b/src/custom_controlnet_aux/sam/utils/transforms.py
@@ -0,0 +1,102 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+from torchvision.transforms.functional import resize, to_pil_image # type: ignore
+
+from copy import deepcopy
+from typing import Tuple
+
+
+class ResizeLongestSide:
+ """
+ Resizes images to the longest side 'target_length', as well as provides
+ methods for resizing coordinates and boxes. Provides methods for
+ transforming both numpy array and batched torch tensors.
+ """
+
+ def __init__(self, target_length: int) -> None:
+ self.target_length = target_length
+
+ def apply_image(self, image: np.ndarray) -> np.ndarray:
+ """
+ Expects a numpy array with shape HxWxC in uint8 format.
+ """
+ target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
+ return np.array(resize(to_pil_image(image), target_size))
+
+ def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
+ """
+ Expects a numpy array of length 2 in the final dimension. Requires the
+ original image size in (H, W) format.
+ """
+ old_h, old_w = original_size
+ new_h, new_w = self.get_preprocess_shape(
+ original_size[0], original_size[1], self.target_length
+ )
+ coords = deepcopy(coords).astype(float)
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
+ return coords
+
+ def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
+ """
+ Expects a numpy array shape Bx4. Requires the original image size
+ in (H, W) format.
+ """
+ boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
+ return boxes.reshape(-1, 4)
+
+ def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
+ """
+ Expects batched images with shape BxCxHxW and float format. This
+ transformation may not exactly match apply_image. apply_image is
+ the transformation expected by the model.
+ """
+ # Expects an image in BCHW format. May not exactly match apply_image.
+ target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
+ return F.interpolate(
+ image, target_size, mode="bilinear", align_corners=False, antialias=True
+ )
+
+ def apply_coords_torch(
+ self, coords: torch.Tensor, original_size: Tuple[int, ...]
+ ) -> torch.Tensor:
+ """
+ Expects a torch tensor with length 2 in the last dimension. Requires the
+ original image size in (H, W) format.
+ """
+ old_h, old_w = original_size
+ new_h, new_w = self.get_preprocess_shape(
+ original_size[0], original_size[1], self.target_length
+ )
+ coords = deepcopy(coords).to(torch.float)
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
+ return coords
+
+ def apply_boxes_torch(
+ self, boxes: torch.Tensor, original_size: Tuple[int, ...]
+ ) -> torch.Tensor:
+ """
+ Expects a torch tensor with shape Bx4. Requires the original image
+ size in (H, W) format.
+ """
+ boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
+ return boxes.reshape(-1, 4)
+
+ @staticmethod
+ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
+ """
+ Compute the output size given input size and target long side length.
+ """
+ scale = long_side_length * 1.0 / max(oldh, oldw)
+ newh, neww = oldh * scale, oldw * scale
+ neww = int(neww + 0.5)
+ newh = int(newh + 0.5)
+ return (newh, neww)
diff --git a/src/custom_controlnet_aux/scribble/__init__.py b/src/custom_controlnet_aux/scribble/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..53bb98b5470b4e7d410ab15948a44cdd595fcf9d
--- /dev/null
+++ b/src/custom_controlnet_aux/scribble/__init__.py
@@ -0,0 +1,41 @@
+import warnings
+import cv2
+import numpy as np
+from PIL import Image
+from custom_controlnet_aux.util import HWC3, resize_image_with_pad, common_input_validate, HWC3
+
+#Not to be confused with "scribble" from HED. That is "fake scribble" which is more accurate and less picky than this.
+class ScribbleDetector:
+ def __call__(self, input_image=None, detect_resolution=512, output_type=None, upscale_method="INTER_AREA", **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+
+ detected_map = np.zeros_like(input_image, dtype=np.uint8)
+ detected_map[np.min(input_image, axis=2) < 127] = 255
+ detected_map = 255 - detected_map
+
+ detected_map = remove_pad(detected_map)
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
+
+class ScribbleXDog_Detector:
+ def __call__(self, input_image=None, detect_resolution=512, thr_a=32, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+
+ g1 = cv2.GaussianBlur(input_image.astype(np.float32), (0, 0), 0.5)
+ g2 = cv2.GaussianBlur(input_image.astype(np.float32), (0, 0), 5.0)
+ dog = (255 - np.min(g2 - g1, axis=2)).clip(0, 255).astype(np.uint8)
+ result = np.zeros_like(input_image, dtype=np.uint8)
+ result[2 * (255 - dog) > thr_a] = 255
+ #result = 255 - result
+
+ detected_map = HWC3(remove_pad(result))
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/shuffle/__init__.py b/src/custom_controlnet_aux/shuffle/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..278db7973d501161cca19f0364c71c76220f5128
--- /dev/null
+++ b/src/custom_controlnet_aux/shuffle/__init__.py
@@ -0,0 +1,87 @@
+import warnings
+
+import cv2
+import numpy as np
+from PIL import Image
+import random
+
+from custom_controlnet_aux.util import HWC3, common_input_validate, img2mask, make_noise_disk, resize_image_with_pad
+
+
+class ContentShuffleDetector:
+ def __call__(self, input_image, h=None, w=None, f=None, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", seed=-1, **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+
+ H, W, C = input_image.shape
+ if h is None:
+ h = H
+ if w is None:
+ w = W
+ if f is None:
+ f = 256
+ rng = np.random.default_rng(seed) if seed else None
+ x = make_noise_disk(h, w, 1, f, rng=rng) * float(W - 1)
+ y = make_noise_disk(h, w, 1, f, rng=rng) * float(H - 1)
+ flow = np.concatenate([x, y], axis=2).astype(np.float32)
+ detected_map = cv2.remap(input_image, flow, None, cv2.INTER_LINEAR)
+ detected_map = remove_pad(detected_map)
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
+
+
+class ColorShuffleDetector:
+ def __call__(self, img):
+ H, W, C = img.shape
+ F = np.random.randint(64, 384)
+ A = make_noise_disk(H, W, 3, F)
+ B = make_noise_disk(H, W, 3, F)
+ C = (A + B) / 2.0
+ A = (C + (A - C) * 3.0).clip(0, 1)
+ B = (C + (B - C) * 3.0).clip(0, 1)
+ L = img.astype(np.float32) / 255.0
+ Y = A * L + B * (1 - L)
+ Y -= np.min(Y, axis=(0, 1), keepdims=True)
+ Y /= np.maximum(np.max(Y, axis=(0, 1), keepdims=True), 1e-5)
+ Y *= 255.0
+ return Y.clip(0, 255).astype(np.uint8)
+
+
+class GrayDetector:
+ def __call__(self, img):
+ eps = 1e-5
+ X = img.astype(np.float32)
+ r, g, b = X[:, :, 0], X[:, :, 1], X[:, :, 2]
+ kr, kg, kb = [random.random() + eps for _ in range(3)]
+ ks = kr + kg + kb
+ kr /= ks
+ kg /= ks
+ kb /= ks
+ Y = r * kr + g * kg + b * kb
+ Y = np.stack([Y] * 3, axis=2)
+ return Y.clip(0, 255).astype(np.uint8)
+
+
+class DownSampleDetector:
+ def __call__(self, img, level=3, k=16.0):
+ h = img.astype(np.float32)
+ for _ in range(level):
+ h += np.random.normal(loc=0.0, scale=k, size=h.shape)
+ h = cv2.pyrDown(h)
+ for _ in range(level):
+ h = cv2.pyrUp(h)
+ h += np.random.normal(loc=0.0, scale=k, size=h.shape)
+ return h.clip(0, 255).astype(np.uint8)
+
+
+class Image2MaskShuffleDetector:
+ def __init__(self, resolution=(640, 512)):
+ self.H, self.W = resolution
+
+ def __call__(self, img):
+ m = img2mask(img, self.H, self.W)
+ m *= 255.0
+ return m.clip(0, 255).astype(np.uint8)
diff --git a/src/custom_controlnet_aux/teed/Fmish.py b/src/custom_controlnet_aux/teed/Fmish.py
new file mode 100644
index 0000000000000000000000000000000000000000..40c867a272bdaf948d435a46a6aaa70478036994
--- /dev/null
+++ b/src/custom_controlnet_aux/teed/Fmish.py
@@ -0,0 +1,17 @@
+"""
+Script provides functional interface for Mish activation function.
+"""
+
+# import pytorch
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def mish(input):
+ """
+ Applies the mish function element-wise:
+ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
+ See additional documentation for mish class.
+ """
+ return input * torch.tanh(F.softplus(input))
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/teed/Fsmish.py b/src/custom_controlnet_aux/teed/Fsmish.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb8c55cad89953f202384eee81173e2b2ae10712
--- /dev/null
+++ b/src/custom_controlnet_aux/teed/Fsmish.py
@@ -0,0 +1,20 @@
+"""
+Script based on:
+Wang, Xueliang, Honge Ren, and Achuan Wang.
+ "Smish: A Novel Activation Function for Deep Learning Methods.
+ " Electronics 11.4 (2022): 540.
+"""
+
+# import pytorch
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def smish(input):
+ """
+ Applies the mish function element-wise:
+ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(sigmoid(x))))
+ See additional documentation for mish class.
+ """
+ return input * torch.tanh(torch.log(1+torch.sigmoid(input)))
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/teed/LICENSE.txt b/src/custom_controlnet_aux/teed/LICENSE.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4a99ffdd7372b1bfa44ea302330343cb7370d0e9
--- /dev/null
+++ b/src/custom_controlnet_aux/teed/LICENSE.txt
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Xavier Soria Poma
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/src/custom_controlnet_aux/teed/Xmish.py b/src/custom_controlnet_aux/teed/Xmish.py
new file mode 100644
index 0000000000000000000000000000000000000000..15e84ed98d165ab6eb4db672dbfdf50ef0953e31
--- /dev/null
+++ b/src/custom_controlnet_aux/teed/Xmish.py
@@ -0,0 +1,43 @@
+"""
+Applies the mish function element-wise:
+mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
+"""
+
+# import pytorch
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+# import activation functions
+from .Fmish import mish
+
+
+class Mish(nn.Module):
+ """
+ Applies the mish function element-wise:
+ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
+ Shape:
+ - Input: (N, *) where * means, any number of additional
+ dimensions
+ - Output: (N, *), same shape as the input
+ Examples:
+ >>> m = Mish()
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+ Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
+ """
+
+ def __init__(self):
+ """
+ Init method.
+ """
+ super().__init__()
+
+ def forward(self, input):
+ """
+ Forward pass of the function.
+ """
+ if torch.__version__ >= "1.9":
+ return F.mish(input)
+ else:
+ return mish(input)
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/teed/Xsmish.py b/src/custom_controlnet_aux/teed/Xsmish.py
new file mode 100644
index 0000000000000000000000000000000000000000..df75bee4d3d3585b1b265435a713ef0186b1e701
--- /dev/null
+++ b/src/custom_controlnet_aux/teed/Xsmish.py
@@ -0,0 +1,43 @@
+"""
+Script based on:
+Wang, Xueliang, Honge Ren, and Achuan Wang.
+ "Smish: A Novel Activation Function for Deep Learning Methods.
+ " Electronics 11.4 (2022): 540.
+smish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + sigmoid(x)))
+"""
+
+# import pytorch
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+# import activation functions
+from .Fsmish import smish
+
+
+class Smish(nn.Module):
+ """
+ Applies the mish function element-wise:
+ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
+ Shape:
+ - Input: (N, *) where * means, any number of additional
+ dimensions
+ - Output: (N, *), same shape as the input
+ Examples:
+ >>> m = Mish()
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+ Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
+ """
+
+ def __init__(self):
+ """
+ Init method.
+ """
+ super().__init__()
+
+ def forward(self, input):
+ """
+ Forward pass of the function.
+ """
+ return smish(input)
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/teed/__init__.py b/src/custom_controlnet_aux/teed/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e08bb3f695e92de4f3cb4bd01a106af9cea1cbfb
--- /dev/null
+++ b/src/custom_controlnet_aux/teed/__init__.py
@@ -0,0 +1,58 @@
+"""
+Hello, welcome on board,
+"""
+from __future__ import print_function
+
+import os
+import cv2
+import numpy as np
+
+import torch
+
+from .ted import TED # TEED architecture
+from einops import rearrange
+from custom_controlnet_aux.util import safe_step, custom_hf_download, BDS_MODEL_NAME, common_input_validate, resize_image_with_pad, HWC3
+from PIL import Image
+
+
+class TEDDetector:
+ def __init__(self, model):
+ self.model = model
+ self.device = "cpu"
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path=BDS_MODEL_NAME, filename="7_model.pth", subfolder="Annotators"):
+ model_path = custom_hf_download(pretrained_model_or_path, filename, subfolder=subfolder)
+ model = TED()
+ model.load_state_dict(torch.load(model_path, map_location='cpu'))
+ model.eval()
+ return cls(model)
+
+ def to(self, device):
+ self.model.to(device)
+ self.device = device
+ return self
+
+
+ def __call__(self, input_image, detect_resolution=512, safe_steps=2, upscale_method="INTER_CUBIC", output_type="pil", **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+
+ H, W, _ = input_image.shape
+ with torch.no_grad():
+ image_teed = torch.from_numpy(input_image.copy()).float().to(self.device)
+ image_teed = rearrange(image_teed, 'h w c -> 1 c h w')
+ edges = self.model(image_teed)
+ edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
+ edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
+ edges = np.stack(edges, axis=2)
+ edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
+ if safe_steps != 0:
+ edge = safe_step(edge, safe_steps)
+ edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
+
+ detected_map = remove_pad(HWC3(edge))
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map[..., :3])
+
+ return detected_map
diff --git a/src/custom_controlnet_aux/teed/ted.py b/src/custom_controlnet_aux/teed/ted.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff347d5acf767126cc95022a2c2036c0262db40d
--- /dev/null
+++ b/src/custom_controlnet_aux/teed/ted.py
@@ -0,0 +1,296 @@
+# TEED: is a Tiny but Efficient Edge Detection, it comes from the LDC-B3
+# with a Slightly modification
+# LDC parameters:
+# 155665
+# TED > 58K
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .Fsmish import smish as Fsmish
+from .Xsmish import Smish
+
+
+def weight_init(m):
+ if isinstance(m, (nn.Conv2d,)):
+ torch.nn.init.xavier_normal_(m.weight, gain=1.0)
+
+ if m.bias is not None:
+ torch.nn.init.zeros_(m.bias)
+
+ # for fusion layer
+ if isinstance(m, (nn.ConvTranspose2d,)):
+ torch.nn.init.xavier_normal_(m.weight, gain=1.0)
+ if m.bias is not None:
+ torch.nn.init.zeros_(m.bias)
+
+class CoFusion(nn.Module):
+ # from LDC
+
+ def __init__(self, in_ch, out_ch):
+ super(CoFusion, self).__init__()
+ self.conv1 = nn.Conv2d(in_ch, 32, kernel_size=3,
+ stride=1, padding=1) # before 64
+ self.conv3= nn.Conv2d(32, out_ch, kernel_size=3,
+ stride=1, padding=1)# before 64 instead of 32
+ self.relu = nn.ReLU()
+ self.norm_layer1 = nn.GroupNorm(4, 32) # before 64
+
+ def forward(self, x):
+ # fusecat = torch.cat(x, dim=1)
+ attn = self.relu(self.norm_layer1(self.conv1(x)))
+ attn = F.softmax(self.conv3(attn), dim=1)
+ return ((x * attn).sum(1)).unsqueeze(1)
+
+
+class CoFusion2(nn.Module):
+ # TEDv14-3
+ def __init__(self, in_ch, out_ch):
+ super(CoFusion2, self).__init__()
+ self.conv1 = nn.Conv2d(in_ch, 32, kernel_size=3,
+ stride=1, padding=1) # before 64
+ # self.conv2 = nn.Conv2d(32, 32, kernel_size=3,
+ # stride=1, padding=1)# before 64
+ self.conv3 = nn.Conv2d(32, out_ch, kernel_size=3,
+ stride=1, padding=1)# before 64 instead of 32
+ self.smish= Smish()#nn.ReLU(inplace=True)
+
+
+ def forward(self, x):
+ # fusecat = torch.cat(x, dim=1)
+ attn = self.conv1(self.smish(x))
+ attn = self.conv3(self.smish(attn)) # before , )dim=1)
+
+ # return ((fusecat * attn).sum(1)).unsqueeze(1)
+ return ((x * attn).sum(1)).unsqueeze(1)
+
+class DoubleFusion(nn.Module):
+ # TED fusion before the final edge map prediction
+ def __init__(self, in_ch, out_ch):
+ super(DoubleFusion, self).__init__()
+ self.DWconv1 = nn.Conv2d(in_ch, in_ch*8, kernel_size=3,
+ stride=1, padding=1, groups=in_ch) # before 64
+ self.PSconv1 = nn.PixelShuffle(1)
+
+ self.DWconv2 = nn.Conv2d(24, 24*1, kernel_size=3,
+ stride=1, padding=1,groups=24)# before 64 instead of 32
+
+ self.AF= Smish()#XAF() #nn.Tanh()# XAF() # # Smish()#
+
+
+ def forward(self, x):
+ # fusecat = torch.cat(x, dim=1)
+ attn = self.PSconv1(self.DWconv1(self.AF(x))) # #TEED best res TEDv14 [8, 32, 352, 352]
+
+ attn2 = self.PSconv1(self.DWconv2(self.AF(attn))) # #TEED best res TEDv14[8, 3, 352, 352]
+
+ return Fsmish(((attn2 +attn).sum(1)).unsqueeze(1)) #TED best res
+
+class _DenseLayer(nn.Sequential):
+ def __init__(self, input_features, out_features):
+ super(_DenseLayer, self).__init__()
+
+ self.add_module('conv1', nn.Conv2d(input_features, out_features,
+ kernel_size=3, stride=1, padding=2, bias=True)),
+ self.add_module('smish1', Smish()),
+ self.add_module('conv2', nn.Conv2d(out_features, out_features,
+ kernel_size=3, stride=1, bias=True))
+ def forward(self, x):
+ x1, x2 = x
+
+ new_features = super(_DenseLayer, self).forward(Fsmish(x1)) # F.relu()
+
+ return 0.5 * (new_features + x2), x2
+
+
+class _DenseBlock(nn.Sequential):
+ def __init__(self, num_layers, input_features, out_features):
+ super(_DenseBlock, self).__init__()
+ for i in range(num_layers):
+ layer = _DenseLayer(input_features, out_features)
+ self.add_module('denselayer%d' % (i + 1), layer)
+ input_features = out_features
+
+
+class UpConvBlock(nn.Module):
+ def __init__(self, in_features, up_scale):
+ super(UpConvBlock, self).__init__()
+ self.up_factor = 2
+ self.constant_features = 16
+
+ layers = self.make_deconv_layers(in_features, up_scale)
+ assert layers is not None, layers
+ self.features = nn.Sequential(*layers)
+
+ def make_deconv_layers(self, in_features, up_scale):
+ layers = []
+ all_pads=[0,0,1,3,7]
+ for i in range(up_scale):
+ kernel_size = 2 ** up_scale
+ pad = all_pads[up_scale] # kernel_size-1
+ out_features = self.compute_out_features(i, up_scale)
+ layers.append(nn.Conv2d(in_features, out_features, 1))
+ layers.append(Smish())
+ layers.append(nn.ConvTranspose2d(
+ out_features, out_features, kernel_size, stride=2, padding=pad))
+ in_features = out_features
+ return layers
+
+ def compute_out_features(self, idx, up_scale):
+ return 1 if idx == up_scale - 1 else self.constant_features
+
+ def forward(self, x):
+ return self.features(x)
+
+
+class SingleConvBlock(nn.Module):
+ def __init__(self, in_features, out_features, stride, use_ac=False):
+ super(SingleConvBlock, self).__init__()
+ # self.use_bn = use_bs
+ self.use_ac=use_ac
+ self.conv = nn.Conv2d(in_features, out_features, 1, stride=stride,
+ bias=True)
+ if self.use_ac:
+ self.smish = Smish()
+
+ def forward(self, x):
+ x = self.conv(x)
+ if self.use_ac:
+ return self.smish(x)
+ else:
+ return x
+
+class DoubleConvBlock(nn.Module):
+ def __init__(self, in_features, mid_features,
+ out_features=None,
+ stride=1,
+ use_act=True):
+ super(DoubleConvBlock, self).__init__()
+
+ self.use_act = use_act
+ if out_features is None:
+ out_features = mid_features
+ self.conv1 = nn.Conv2d(in_features, mid_features,
+ 3, padding=1, stride=stride)
+ self.conv2 = nn.Conv2d(mid_features, out_features, 3, padding=1)
+ self.smish= Smish()#nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.smish(x)
+ x = self.conv2(x)
+ if self.use_act:
+ x = self.smish(x)
+ return x
+
+
+class TED(nn.Module):
+ """ Definition of Tiny and Efficient Edge Detector
+ model
+ """
+
+ def __init__(self):
+ super(TED, self).__init__()
+ self.block_1 = DoubleConvBlock(3, 16, 16, stride=2,)
+ self.block_2 = DoubleConvBlock(16, 32, use_act=False)
+ self.dblock_3 = _DenseBlock(1, 32, 48) # [32,48,100,100] before (2, 32, 64)
+
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ # skip1 connection, see fig. 2
+ self.side_1 = SingleConvBlock(16, 32, 2)
+
+ # skip2 connection, see fig. 2
+ self.pre_dense_3 = SingleConvBlock(32, 48, 1) # before (32, 64, 1)
+
+ # USNet
+ self.up_block_1 = UpConvBlock(16, 1)
+ self.up_block_2 = UpConvBlock(32, 1)
+ self.up_block_3 = UpConvBlock(48, 2) # (32, 64, 1)
+
+ self.block_cat = DoubleFusion(3,3) # TEED: DoubleFusion
+
+ self.apply(weight_init)
+
+ def slice(self, tensor, slice_shape):
+ t_shape = tensor.shape
+ img_h, img_w = slice_shape
+ if img_w!=t_shape[-1] or img_h!=t_shape[2]:
+ new_tensor = F.interpolate(
+ tensor, size=(img_h, img_w), mode='bicubic',align_corners=False)
+
+ else:
+ new_tensor=tensor
+ # tensor[..., :height, :width]
+ return new_tensor
+ def resize_input(self,tensor):
+ t_shape = tensor.shape
+ if t_shape[2] % 8 != 0 or t_shape[3] % 8 != 0:
+ img_w= ((t_shape[3]// 8) + 1) * 8
+ img_h = ((t_shape[2] // 8) + 1) * 8
+ new_tensor = F.interpolate(
+ tensor, size=(img_h, img_w), mode='bicubic', align_corners=False)
+ else:
+ new_tensor = tensor
+ return new_tensor
+
+ def crop_bdcn(data1, h, w, crop_h, crop_w):
+ # Based on BDCN Implementation @ https://github.com/pkuCactus/BDCN
+ _, _, h1, w1 = data1.size()
+ assert (h <= h1 and w <= w1)
+ data = data1[:, :, crop_h:crop_h + h, crop_w:crop_w + w]
+ return data
+
+
+ def forward(self, x, single_test=False):
+ assert x.ndim == 4, x.shape
+ # supose the image size is 352x352
+
+ # Block 1
+ block_1 = self.block_1(x) # [8,16,176,176]
+ block_1_side = self.side_1(block_1) # 16 [8,32,88,88]
+
+ # Block 2
+ block_2 = self.block_2(block_1) # 32 # [8,32,176,176]
+ block_2_down = self.maxpool(block_2) # [8,32,88,88]
+ block_2_add = block_2_down + block_1_side # [8,32,88,88]
+
+ # Block 3
+ block_3_pre_dense = self.pre_dense_3(block_2_down) # [8,64,88,88] block 3 L connection
+ block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense]) # [8,64,88,88]
+
+ # upsampling blocks
+ out_1 = self.up_block_1(block_1)
+ out_2 = self.up_block_2(block_2)
+ out_3 = self.up_block_3(block_3)
+
+ results = [out_1, out_2, out_3]
+
+ # concatenate multiscale outputs
+ block_cat = torch.cat(results, dim=1) # Bx6xHxW
+ block_cat = self.block_cat(block_cat) # Bx1xHxW DoubleFusion
+
+ results.append(block_cat)
+ return results
+
+
+if __name__ == '__main__':
+ batch_size = 8
+ img_height = 352
+ img_width = 352
+
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
+ device = "cpu"
+ input = torch.rand(batch_size, 3, img_height, img_width).to(device)
+ # target = torch.rand(batch_size, 1, img_height, img_width).to(device)
+ print(f"input shape: {input.shape}")
+ model = TED().to(device)
+ output = model(input)
+ print(f"output shapes: {[t.shape for t in output]}")
+
+ # for i in range(20000):
+ # print(i)
+ # output = model(input)
+ # loss = nn.MSELoss()(output[-1], target)
+ # loss.backward()
diff --git a/src/custom_controlnet_aux/tile/__init__.py b/src/custom_controlnet_aux/tile/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c58555bcaa928a3044c50e30bf2021b0e3b1926
--- /dev/null
+++ b/src/custom_controlnet_aux/tile/__init__.py
@@ -0,0 +1,82 @@
+import warnings
+import cv2
+import numpy as np
+from PIL import Image
+from custom_controlnet_aux.util import get_upscale_method, common_input_validate, HWC3
+from .guided_filter import FastGuidedFilter
+
+class TileDetector:
+ def __call__(self, input_image=None, pyrUp_iters=3, output_type=None, upscale_method="INTER_AREA", **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ H, W, _ = input_image.shape
+ H = int(np.round(H / 64.0)) * 64
+ W = int(np.round(W / 64.0)) * 64
+ detected_map = cv2.resize(input_image, (W // (2 ** pyrUp_iters), H // (2 ** pyrUp_iters)),
+ interpolation=get_upscale_method(upscale_method))
+ detected_map = HWC3(detected_map)
+
+ for _ in range(pyrUp_iters):
+ detected_map = cv2.pyrUp(detected_map)
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
+
+
+# Source: https://huggingface.co/TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic/blob/main/TTP_tile_preprocessor_v5.py
+
+def apply_gaussian_blur(image_np, ksize=5, sigmaX=1.0):
+ if ksize % 2 == 0:
+ ksize += 1 # ksize must be odd
+ blurred_image = cv2.GaussianBlur(image_np, (ksize, ksize), sigmaX=sigmaX)
+ return blurred_image
+
+def apply_guided_filter(image_np, radius, eps, scale):
+ filter = FastGuidedFilter(image_np, radius, eps, scale)
+ return filter.filter(image_np)
+
+class TTPlanet_Tile_Detector_GF:
+ def __call__(self, input_image, scale_factor, blur_strength, radius, eps, output_type=None, **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ img_np = input_image[:, :, ::-1] # RGB to BGR
+
+ # Apply Gaussian blur
+ img_np = apply_gaussian_blur(img_np, ksize=int(blur_strength), sigmaX=blur_strength / 2)
+
+ # Apply Guided Filter
+ img_np = apply_guided_filter(img_np, radius, eps, scale_factor)
+
+ # Resize image
+ height, width = img_np.shape[:2]
+ new_width = int(width / scale_factor)
+ new_height = int(height / scale_factor)
+ resized_down = cv2.resize(img_np, (new_width, new_height), interpolation=cv2.INTER_AREA)
+ resized_img = cv2.resize(resized_down, (width, height), interpolation=cv2.INTER_CUBIC)
+ detected_map = HWC3(resized_img[:, :, ::-1]) # BGR to RGB
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
+
+class TTPLanet_Tile_Detector_Simple:
+ def __call__(self, input_image, scale_factor, blur_strength, output_type=None, **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ img_np = input_image[:, :, ::-1] # RGB to BGR
+
+ # Resize image first if you want blur to apply after resizing
+ height, width = img_np.shape[:2]
+ new_width = int(width / scale_factor)
+ new_height = int(height / scale_factor)
+ resized_down = cv2.resize(img_np, (new_width, new_height), interpolation=cv2.INTER_AREA)
+ resized_img = cv2.resize(resized_down, (width, height), interpolation=cv2.INTER_LANCZOS4)
+
+ # Apply Gaussian blur after resizing
+ img_np = apply_gaussian_blur(resized_img, ksize=int(blur_strength), sigmaX=blur_strength / 2)
+ detected_map = HWC3(img_np[:, :, ::-1]) # BGR to RGB
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
diff --git a/src/custom_controlnet_aux/tile/guided_filter.py b/src/custom_controlnet_aux/tile/guided_filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..88b953e504ccd5cf7fe00efa81207d41f5421091
--- /dev/null
+++ b/src/custom_controlnet_aux/tile/guided_filter.py
@@ -0,0 +1,281 @@
+
+# -*- coding: utf-8 -*-
+## @package guided_filter.core.filters
+#
+# Implementation of guided filter.
+# * GuidedFilter: Original guided filter.
+# * FastGuidedFilter: Fast version of the guided filter.
+# @author tody
+# @date 2015/08/26
+
+import numpy as np
+import cv2
+
+## Convert image into float32 type.
+def to32F(img):
+ if img.dtype == np.float32:
+ return img
+ return (1.0 / 255.0) * np.float32(img)
+
+## Convert image into uint8 type.
+def to8U(img):
+ if img.dtype == np.uint8:
+ return img
+ return np.clip(np.uint8(255.0 * img), 0, 255)
+
+## Return if the input image is gray or not.
+def _isGray(I):
+ return len(I.shape) == 2
+
+
+## Return down sampled image.
+# @param scale (w/s, h/s) image will be created.
+# @param shape I.shape[:2]=(h, w). numpy friendly size parameter.
+def _downSample(I, scale=4, shape=None):
+ if shape is not None:
+ h, w = shape
+ return cv2.resize(I, (w, h), interpolation=cv2.INTER_NEAREST)
+
+ h, w = I.shape[:2]
+ return cv2.resize(I, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_NEAREST)
+
+
+## Return up sampled image.
+# @param scale (w*s, h*s) image will be created.
+# @param shape I.shape[:2]=(h, w). numpy friendly size parameter.
+def _upSample(I, scale=2, shape=None):
+ if shape is not None:
+ h, w = shape
+ return cv2.resize(I, (w, h), interpolation=cv2.INTER_LINEAR)
+
+ h, w = I.shape[:2]
+ return cv2.resize(I, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR)
+
+## Fast guide filter.
+class FastGuidedFilter:
+ ## Constructor.
+ # @param I Input guidance image. Color or gray.
+ # @param radius Radius of Guided Filter.
+ # @param epsilon Regularization term of Guided Filter.
+ # @param scale Down sampled scale.
+ def __init__(self, I, radius=5, epsilon=0.4, scale=4):
+ I_32F = to32F(I)
+ self._I = I_32F
+ h, w = I.shape[:2]
+
+ I_sub = _downSample(I_32F, scale)
+
+ self._I_sub = I_sub
+ radius = int(radius / scale)
+
+ if _isGray(I):
+ self._guided_filter = GuidedFilterGray(I_sub, radius, epsilon)
+ else:
+ self._guided_filter = GuidedFilterColor(I_sub, radius, epsilon)
+
+ ## Apply filter for the input image.
+ # @param p Input image for the filtering.
+ def filter(self, p):
+ p_32F = to32F(p)
+ shape_original = p.shape[:2]
+
+ p_sub = _downSample(p_32F, shape=self._I_sub.shape[:2])
+
+ if _isGray(p_sub):
+ return self._filterGray(p_sub, shape_original)
+
+ cs = p.shape[2]
+ q = np.array(p_32F)
+
+ for ci in range(cs):
+ q[:, :, ci] = self._filterGray(p_sub[:, :, ci], shape_original)
+ return to8U(q)
+
+ def _filterGray(self, p_sub, shape_original):
+ ab_sub = self._guided_filter._computeCoefficients(p_sub)
+ ab = [_upSample(abi, shape=shape_original) for abi in ab_sub]
+ return self._guided_filter._computeOutput(ab, self._I)
+
+
+## Guide filter.
+class GuidedFilter:
+ ## Constructor.
+ # @param I Input guidance image. Color or gray.
+ # @param radius Radius of Guided Filter.
+ # @param epsilon Regularization term of Guided Filter.
+ def __init__(self, I, radius=5, epsilon=0.4):
+ I_32F = to32F(I)
+
+ if _isGray(I):
+ self._guided_filter = GuidedFilterGray(I_32F, radius, epsilon)
+ else:
+ self._guided_filter = GuidedFilterColor(I_32F, radius, epsilon)
+
+ ## Apply filter for the input image.
+ # @param p Input image for the filtering.
+ def filter(self, p):
+ return to8U(self._guided_filter.filter(p))
+
+
+## Common parts of guided filter.
+#
+# This class is used by guided_filter class. GuidedFilterGray and GuidedFilterColor.
+# Based on guided_filter._computeCoefficients, guided_filter._computeOutput,
+# GuidedFilterCommon.filter computes filtered image for color and gray.
+class GuidedFilterCommon:
+ def __init__(self, guided_filter):
+ self._guided_filter = guided_filter
+
+ ## Apply filter for the input image.
+ # @param p Input image for the filtering.
+ def filter(self, p):
+ p_32F = to32F(p)
+ if _isGray(p_32F):
+ return self._filterGray(p_32F)
+
+ cs = p.shape[2]
+ q = np.array(p_32F)
+
+ for ci in range(cs):
+ q[:, :, ci] = self._filterGray(p_32F[:, :, ci])
+ return q
+
+ def _filterGray(self, p):
+ ab = self._guided_filter._computeCoefficients(p)
+ return self._guided_filter._computeOutput(ab, self._guided_filter._I)
+
+
+## Guided filter for gray guidance image.
+class GuidedFilterGray:
+ # @param I Input gray guidance image.
+ # @param radius Radius of Guided Filter.
+ # @param epsilon Regularization term of Guided Filter.
+ def __init__(self, I, radius=5, epsilon=0.4):
+ self._radius = 2 * radius + 1
+ self._epsilon = epsilon
+ self._I = to32F(I)
+ self._initFilter()
+ self._filter_common = GuidedFilterCommon(self)
+
+ ## Apply filter for the input image.
+ # @param p Input image for the filtering.
+ def filter(self, p):
+ return self._filter_common.filter(p)
+
+ def _initFilter(self):
+ I = self._I
+ r = self._radius
+ self._I_mean = cv2.blur(I, (r, r))
+ I_mean_sq = cv2.blur(I ** 2, (r, r))
+ self._I_var = I_mean_sq - self._I_mean ** 2
+
+ def _computeCoefficients(self, p):
+ r = self._radius
+ p_mean = cv2.blur(p, (r, r))
+ p_cov = p_mean - self._I_mean * p_mean
+ a = p_cov / (self._I_var + self._epsilon)
+ b = p_mean - a * self._I_mean
+ a_mean = cv2.blur(a, (r, r))
+ b_mean = cv2.blur(b, (r, r))
+ return a_mean, b_mean
+
+ def _computeOutput(self, ab, I):
+ a_mean, b_mean = ab
+ return a_mean * I + b_mean
+
+
+## Guided filter for color guidance image.
+class GuidedFilterColor:
+ # @param I Input color guidance image.
+ # @param radius Radius of Guided Filter.
+ # @param epsilon Regularization term of Guided Filter.
+ def __init__(self, I, radius=5, epsilon=0.2):
+ self._radius = 2 * radius + 1
+ self._epsilon = epsilon
+ self._I = to32F(I)
+ self._initFilter()
+ self._filter_common = GuidedFilterCommon(self)
+
+ ## Apply filter for the input image.
+ # @param p Input image for the filtering.
+ def filter(self, p):
+ return self._filter_common.filter(p)
+
+ def _initFilter(self):
+ I = self._I
+ r = self._radius
+ eps = self._epsilon
+
+ Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]
+
+ self._Ir_mean = cv2.blur(Ir, (r, r))
+ self._Ig_mean = cv2.blur(Ig, (r, r))
+ self._Ib_mean = cv2.blur(Ib, (r, r))
+
+ Irr_var = cv2.blur(Ir ** 2, (r, r)) - self._Ir_mean ** 2 + eps
+ Irg_var = cv2.blur(Ir * Ig, (r, r)) - self._Ir_mean * self._Ig_mean
+ Irb_var = cv2.blur(Ir * Ib, (r, r)) - self._Ir_mean * self._Ib_mean
+ Igg_var = cv2.blur(Ig * Ig, (r, r)) - self._Ig_mean * self._Ig_mean + eps
+ Igb_var = cv2.blur(Ig * Ib, (r, r)) - self._Ig_mean * self._Ib_mean
+ Ibb_var = cv2.blur(Ib * Ib, (r, r)) - self._Ib_mean * self._Ib_mean + eps
+
+ Irr_inv = Igg_var * Ibb_var - Igb_var * Igb_var
+ Irg_inv = Igb_var * Irb_var - Irg_var * Ibb_var
+ Irb_inv = Irg_var * Igb_var - Igg_var * Irb_var
+ Igg_inv = Irr_var * Ibb_var - Irb_var * Irb_var
+ Igb_inv = Irb_var * Irg_var - Irr_var * Igb_var
+ Ibb_inv = Irr_var * Igg_var - Irg_var * Irg_var
+
+ I_cov = Irr_inv * Irr_var + Irg_inv * Irg_var + Irb_inv * Irb_var
+ Irr_inv /= I_cov
+ Irg_inv /= I_cov
+ Irb_inv /= I_cov
+ Igg_inv /= I_cov
+ Igb_inv /= I_cov
+ Ibb_inv /= I_cov
+
+ self._Irr_inv = Irr_inv
+ self._Irg_inv = Irg_inv
+ self._Irb_inv = Irb_inv
+ self._Igg_inv = Igg_inv
+ self._Igb_inv = Igb_inv
+ self._Ibb_inv = Ibb_inv
+
+ def _computeCoefficients(self, p):
+ r = self._radius
+ I = self._I
+ Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]
+
+ p_mean = cv2.blur(p, (r, r))
+
+ Ipr_mean = cv2.blur(Ir * p, (r, r))
+ Ipg_mean = cv2.blur(Ig * p, (r, r))
+ Ipb_mean = cv2.blur(Ib * p, (r, r))
+
+ Ipr_cov = Ipr_mean - self._Ir_mean * p_mean
+ Ipg_cov = Ipg_mean - self._Ig_mean * p_mean
+ Ipb_cov = Ipb_mean - self._Ib_mean * p_mean
+
+ ar = self._Irr_inv * Ipr_cov + self._Irg_inv * Ipg_cov + self._Irb_inv * Ipb_cov
+ ag = self._Irg_inv * Ipr_cov + self._Igg_inv * Ipg_cov + self._Igb_inv * Ipb_cov
+ ab = self._Irb_inv * Ipr_cov + self._Igb_inv * Ipg_cov + self._Ibb_inv * Ipb_cov
+ b = p_mean - ar * self._Ir_mean - ag * self._Ig_mean - ab * self._Ib_mean
+
+ ar_mean = cv2.blur(ar, (r, r))
+ ag_mean = cv2.blur(ag, (r, r))
+ ab_mean = cv2.blur(ab, (r, r))
+ b_mean = cv2.blur(b, (r, r))
+
+ return ar_mean, ag_mean, ab_mean, b_mean
+
+ def _computeOutput(self, ab, I):
+ ar_mean, ag_mean, ab_mean, b_mean = ab
+
+ Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]
+
+ q = (ar_mean * Ir +
+ ag_mean * Ig +
+ ab_mean * Ib +
+ b_mean)
+
+ return q
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/uniformer/__init__.py b/src/custom_controlnet_aux/uniformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccb75770f8a34809975f28c10116b7ccb5a112b1
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/__init__.py
@@ -0,0 +1,68 @@
+import os
+from .inference import init_segmentor, inference_segmentor, show_result_pyplot
+import warnings
+import cv2
+import numpy as np
+from PIL import Image
+from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, HF_MODEL_NAME
+import torch
+
+from custom_mmpkg.custom_mmseg.core.evaluation import get_palette
+
+config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "upernet_global_small.py")
+
+
+
+class UniformerSegmentor:
+ def __init__(self, netNetwork):
+ self.model = netNetwork
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, filename="upernet_global_small.pth"):
+ model_path = custom_hf_download(pretrained_model_or_path, filename)
+
+ netNetwork = init_segmentor(config_file, model_path, device="cpu")
+ netNetwork.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_path)['state_dict'].items()})
+ netNetwork.eval()
+
+ return cls(netNetwork)
+
+ def to(self, device):
+ self.model.to(device)
+ return self
+
+ def _inference(self, img):
+ if next(self.model.parameters()).device.type == 'mps':
+ # adaptive_avg_pool2d can fail on MPS, workaround with CPU
+ import torch.nn.functional
+
+ orig_adaptive_avg_pool2d = torch.nn.functional.adaptive_avg_pool2d
+ def cpu_if_exception(input, *args, **kwargs):
+ try:
+ return orig_adaptive_avg_pool2d(input, *args, **kwargs)
+ except:
+ return orig_adaptive_avg_pool2d(input.cpu(), *args, **kwargs).to(input.device)
+
+ try:
+ torch.nn.functional.adaptive_avg_pool2d = cpu_if_exception
+ result = inference_segmentor(self.model, img)
+ finally:
+ torch.nn.functional.adaptive_avg_pool2d = orig_adaptive_avg_pool2d
+ else:
+ result = inference_segmentor(self.model, img)
+
+ res_img = show_result_pyplot(self.model, img, result, get_palette('ade'), opacity=1)
+ return res_img
+
+ def __call__(self, input_image=None, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+
+ detected_map = self._inference(input_image)
+ detected_map = remove_pad(HWC3(detected_map))
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
+
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/ade20k.py b/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/ade20k.py
new file mode 100644
index 0000000000000000000000000000000000000000..efc8b4bb20c981f3db6df7eb52b3dc0744c94cc0
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/ade20k.py
@@ -0,0 +1,54 @@
+# dataset settings
+dataset_type = 'ADE20KDataset'
+data_root = 'data/ade/ADEChallengeData2016'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', reduce_zero_label=True),
+ dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(2048, 512),
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/training',
+ ann_dir='annotations/training',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/chase_db1.py b/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/chase_db1.py
new file mode 100644
index 0000000000000000000000000000000000000000..298594ea925f87f22b37094a2ec50e370aec96a0
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/chase_db1.py
@@ -0,0 +1,59 @@
+# dataset settings
+dataset_type = 'ChaseDB1Dataset'
+data_root = 'data/CHASE_DB1'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+img_scale = (960, 999)
+crop_size = (128, 128)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=img_scale,
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img'])
+ ])
+]
+
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type='RepeatDataset',
+ times=40000,
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/training',
+ ann_dir='annotations/training',
+ pipeline=train_pipeline)),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/cityscapes.py b/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/cityscapes.py
new file mode 100644
index 0000000000000000000000000000000000000000..f21867c63e1835f6fceb61f066e802fd8fd2a735
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/cityscapes.py
@@ -0,0 +1,54 @@
+# dataset settings
+dataset_type = 'CityscapesDataset'
+data_root = 'data/cityscapes/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (512, 1024)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(2048, 1024),
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=2,
+ workers_per_gpu=2,
+ train=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='leftImg8bit/train',
+ ann_dir='gtFine/train',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='leftImg8bit/val',
+ ann_dir='gtFine/val',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='leftImg8bit/val',
+ ann_dir='gtFine/val',
+ pipeline=test_pipeline))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/cityscapes_769x769.py b/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/cityscapes_769x769.py
new file mode 100644
index 0000000000000000000000000000000000000000..336c7b254fe392b4703039fec86a83acdbd2e1a5
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/cityscapes_769x769.py
@@ -0,0 +1,35 @@
+_base_ = './cityscapes.py'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (769, 769)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(2049, 1025),
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ train=dict(pipeline=train_pipeline),
+ val=dict(pipeline=test_pipeline),
+ test=dict(pipeline=test_pipeline))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/drive.py b/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/drive.py
new file mode 100644
index 0000000000000000000000000000000000000000..06e8ff606e0d2a4514ec8b7d2c6c436a32efcbf4
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/drive.py
@@ -0,0 +1,59 @@
+# dataset settings
+dataset_type = 'DRIVEDataset'
+data_root = 'data/DRIVE'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+img_scale = (584, 565)
+crop_size = (64, 64)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=img_scale,
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img'])
+ ])
+]
+
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type='RepeatDataset',
+ times=40000,
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/training',
+ ann_dir='annotations/training',
+ pipeline=train_pipeline)),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/hrf.py b/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/hrf.py
new file mode 100644
index 0000000000000000000000000000000000000000..242d790eb1b83e75cf6b7eaa7a35c674099311ad
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/hrf.py
@@ -0,0 +1,59 @@
+# dataset settings
+dataset_type = 'HRFDataset'
+data_root = 'data/HRF'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+img_scale = (2336, 3504)
+crop_size = (256, 256)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=img_scale,
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img'])
+ ])
+]
+
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type='RepeatDataset',
+ times=40000,
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/training',
+ ann_dir='annotations/training',
+ pipeline=train_pipeline)),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/pascal_context.py b/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/pascal_context.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff65bad1b86d7e3a5980bb5b9fc55798dc8df5f4
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/pascal_context.py
@@ -0,0 +1,60 @@
+# dataset settings
+dataset_type = 'PascalContextDataset'
+data_root = 'data/VOCdevkit/VOC2010/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+img_scale = (520, 520)
+crop_size = (480, 480)
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=img_scale,
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClassContext',
+ split='ImageSets/SegmentationContext/train.txt',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClassContext',
+ split='ImageSets/SegmentationContext/val.txt',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClassContext',
+ split='ImageSets/SegmentationContext/val.txt',
+ pipeline=test_pipeline))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/pascal_context_59.py b/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/pascal_context_59.py
new file mode 100644
index 0000000000000000000000000000000000000000..37585abab89834b95cd5bdd993b994fca1db65f6
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/pascal_context_59.py
@@ -0,0 +1,60 @@
+# dataset settings
+dataset_type = 'PascalContextDataset59'
+data_root = 'data/VOCdevkit/VOC2010/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+img_scale = (520, 520)
+crop_size = (480, 480)
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', reduce_zero_label=True),
+ dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=img_scale,
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClassContext',
+ split='ImageSets/SegmentationContext/train.txt',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClassContext',
+ split='ImageSets/SegmentationContext/val.txt',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClassContext',
+ split='ImageSets/SegmentationContext/val.txt',
+ pipeline=test_pipeline))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/pascal_voc12.py b/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/pascal_voc12.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba1d42d0c5781f56dc177d860d856bb34adce555
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/pascal_voc12.py
@@ -0,0 +1,57 @@
+# dataset settings
+dataset_type = 'PascalVOCDataset'
+data_root = 'data/VOCdevkit/VOC2012'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(2048, 512),
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClass',
+ split='ImageSets/Segmentation/train.txt',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClass',
+ split='ImageSets/Segmentation/val.txt',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='JPEGImages',
+ ann_dir='SegmentationClass',
+ split='ImageSets/Segmentation/val.txt',
+ pipeline=test_pipeline))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/pascal_voc12_aug.py b/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/pascal_voc12_aug.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f23b6717d53ad29f02dd15046802a2631a5076b
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/pascal_voc12_aug.py
@@ -0,0 +1,9 @@
+_base_ = './pascal_voc12.py'
+# dataset settings
+data = dict(
+ train=dict(
+ ann_dir=['SegmentationClass', 'SegmentationClassAug'],
+ split=[
+ 'ImageSets/Segmentation/train.txt',
+ 'ImageSets/Segmentation/aug.txt'
+ ]))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/stare.py b/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/stare.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f71b25488cc11a6b4d582ac52b5a24e1ad1cf8e
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/datasets/stare.py
@@ -0,0 +1,59 @@
+# dataset settings
+dataset_type = 'STAREDataset'
+data_root = 'data/STARE'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+img_scale = (605, 700)
+crop_size = (128, 128)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=img_scale,
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img'])
+ ])
+]
+
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type='RepeatDataset',
+ times=40000,
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/training',
+ ann_dir='annotations/training',
+ pipeline=train_pipeline)),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='images/validation',
+ ann_dir='annotations/validation',
+ pipeline=test_pipeline))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/default_runtime.py b/src/custom_controlnet_aux/uniformer/configs/_base_/default_runtime.py
new file mode 100644
index 0000000000000000000000000000000000000000..b564cc4e7e7d9a67dacaaddecb100e4d8f5c005b
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/default_runtime.py
@@ -0,0 +1,14 @@
+# yapf:disable
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook', by_epoch=False),
+ # dict(type='TensorboardLoggerHook')
+ ])
+# yapf:enable
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
+cudnn_benchmark = True
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/ann_r50-d8.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/ann_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2cb653827e44e6015b3b83bc578003e614a6aa1
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/ann_r50-d8.py
@@ -0,0 +1,46 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='ANNHead',
+ in_channels=[1024, 2048],
+ in_index=[2, 3],
+ channels=512,
+ project_channels=256,
+ query_scales=(1, ),
+ key_pool_scales=(1, 3, 6, 8),
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/apcnet_r50-d8.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/apcnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8f5316cbcf3896ba9de7ca2c801eba512f01d5e
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/apcnet_r50-d8.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='APCHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ pool_scales=(1, 2, 3, 6),
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=dict(type='SyncBN', requires_grad=True),
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/ccnet_r50-d8.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/ccnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..794148f576b9e215c3c6963e73dffe98204b7717
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/ccnet_r50-d8.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='CCHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ recurrence=2,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/cgnet.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/cgnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..eff8d9458c877c5db894957e0b1b4597e40da6ab
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/cgnet.py
@@ -0,0 +1,35 @@
+# model settings
+norm_cfg = dict(type='SyncBN', eps=1e-03, requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ backbone=dict(
+ type='CGNet',
+ norm_cfg=norm_cfg,
+ in_channels=3,
+ num_channels=(32, 64, 128),
+ num_blocks=(3, 21),
+ dilations=(2, 4),
+ reductions=(8, 16)),
+ decode_head=dict(
+ type='FCNHead',
+ in_channels=256,
+ in_index=2,
+ channels=256,
+ num_convs=0,
+ concat_input=False,
+ dropout_ratio=0,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0,
+ class_weight=[
+ 2.5959933, 6.7415504, 3.5354059, 9.8663225, 9.690899, 9.369352,
+ 10.289121, 9.953208, 4.3097677, 9.490387, 7.674431, 9.396905,
+ 10.347791, 6.3927646, 10.226669, 10.241062, 10.280587,
+ 10.396974, 10.055647
+ ])),
+ # model training and testing settings
+ train_cfg=dict(sampler=None),
+ test_cfg=dict(mode='whole'))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/danet_r50-d8.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/danet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c934939fac48525f22ad86f489a041dd7db7d09
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/danet_r50-d8.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='DAHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ pam_channels=64,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/deeplabv3_r50-d8.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/deeplabv3_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7a43bee01422ad4795dd27874e0cd4bb6cbfecf
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/deeplabv3_r50-d8.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='ASPPHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ dilations=(1, 12, 24, 36),
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/deeplabv3_unet_s5-d16.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/deeplabv3_unet_s5-d16.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cd262999d8b2cb8e14a5c32190ae73f479d8e81
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/deeplabv3_unet_s5-d16.py
@@ -0,0 +1,50 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained=None,
+ backbone=dict(
+ type='UNet',
+ in_channels=3,
+ base_channels=64,
+ num_stages=5,
+ strides=(1, 1, 1, 1, 1),
+ enc_num_convs=(2, 2, 2, 2, 2),
+ dec_num_convs=(2, 2, 2, 2),
+ downsamples=(True, True, True, True),
+ enc_dilations=(1, 1, 1, 1, 1),
+ dec_dilations=(1, 1, 1, 1),
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type='ReLU'),
+ upsample_cfg=dict(type='InterpConv'),
+ norm_eval=False),
+ decode_head=dict(
+ type='ASPPHead',
+ in_channels=64,
+ in_index=4,
+ channels=16,
+ dilations=(1, 12, 24, 36),
+ dropout_ratio=0.1,
+ num_classes=2,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=128,
+ in_index=3,
+ channels=64,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=2,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=256, stride=170))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/deeplabv3plus_r50-d8.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/deeplabv3plus_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..050e39e091d816df9028d23aa3ecf9db74e441e1
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/deeplabv3plus_r50-d8.py
@@ -0,0 +1,46 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='DepthwiseSeparableASPPHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ dilations=(1, 12, 24, 36),
+ c1_in_channels=256,
+ c1_channels=48,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/dmnet_r50-d8.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/dmnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..d22ba52640bebd805b3b8d07025e276dfb023759
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/dmnet_r50-d8.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='DMHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ filter_sizes=(1, 3, 5, 7),
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=dict(type='SyncBN', requires_grad=True),
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/dnl_r50-d8.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/dnl_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..edb4c174c51e34c103737ba39bfc48bf831e561d
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/dnl_r50-d8.py
@@ -0,0 +1,46 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='DNLHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ dropout_ratio=0.1,
+ reduction=2,
+ use_scale=True,
+ mode='embedded_gaussian',
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/emanet_r50-d8.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/emanet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..26adcd430926de0862204a71d345f2543167f27b
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/emanet_r50-d8.py
@@ -0,0 +1,47 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='EMAHead',
+ in_channels=2048,
+ in_index=3,
+ channels=256,
+ ema_channels=512,
+ num_bases=64,
+ num_stages=3,
+ momentum=0.1,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/encnet_r50-d8.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/encnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..be777123a886503172a95fe0719e956a147bbd68
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/encnet_r50-d8.py
@@ -0,0 +1,48 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='EncHead',
+ in_channels=[512, 1024, 2048],
+ in_index=(1, 2, 3),
+ channels=512,
+ num_codes=32,
+ use_se_loss=True,
+ add_lateral=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+ loss_se_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.2)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/fast_scnn.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/fast_scnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..32fdeb659355a5ce5ef2cc7c2f30742703811cdf
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/fast_scnn.py
@@ -0,0 +1,57 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True, momentum=0.01)
+model = dict(
+ type='EncoderDecoder',
+ backbone=dict(
+ type='FastSCNN',
+ downsample_dw_channels=(32, 48),
+ global_in_channels=64,
+ global_block_channels=(64, 96, 128),
+ global_block_strides=(2, 2, 1),
+ global_out_channels=128,
+ higher_in_channels=64,
+ lower_in_channels=128,
+ fusion_out_channels=128,
+ out_indices=(0, 1, 2),
+ norm_cfg=norm_cfg,
+ align_corners=False),
+ decode_head=dict(
+ type='DepthwiseSeparableFCNHead',
+ in_channels=128,
+ channels=128,
+ concat_input=False,
+ num_classes=19,
+ in_index=-1,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
+ auxiliary_head=[
+ dict(
+ type='FCNHead',
+ in_channels=128,
+ channels=32,
+ num_convs=1,
+ num_classes=19,
+ in_index=-2,
+ norm_cfg=norm_cfg,
+ concat_input=False,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
+ dict(
+ type='FCNHead',
+ in_channels=64,
+ channels=32,
+ num_convs=1,
+ num_classes=19,
+ in_index=-3,
+ norm_cfg=norm_cfg,
+ concat_input=False,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
+ ],
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/fcn_hr18.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/fcn_hr18.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3e299bc89ada56ca14bbffcbdb08a586b8ed9e9
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/fcn_hr18.py
@@ -0,0 +1,52 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://msra/hrnetv2_w18',
+ backbone=dict(
+ type='HRNet',
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ extra=dict(
+ stage1=dict(
+ num_modules=1,
+ num_branches=1,
+ block='BOTTLENECK',
+ num_blocks=(4, ),
+ num_channels=(64, )),
+ stage2=dict(
+ num_modules=1,
+ num_branches=2,
+ block='BASIC',
+ num_blocks=(4, 4),
+ num_channels=(18, 36)),
+ stage3=dict(
+ num_modules=4,
+ num_branches=3,
+ block='BASIC',
+ num_blocks=(4, 4, 4),
+ num_channels=(18, 36, 72)),
+ stage4=dict(
+ num_modules=3,
+ num_branches=4,
+ block='BASIC',
+ num_blocks=(4, 4, 4, 4),
+ num_channels=(18, 36, 72, 144)))),
+ decode_head=dict(
+ type='FCNHead',
+ in_channels=[18, 36, 72, 144],
+ in_index=(0, 1, 2, 3),
+ channels=sum([18, 36, 72, 144]),
+ input_transform='resize_concat',
+ kernel_size=1,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=-1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/fcn_r50-d8.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/fcn_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e98f6cc918b6146fc6d613c6918e825ef1355c3
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/fcn_r50-d8.py
@@ -0,0 +1,45 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='FCNHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ num_convs=2,
+ concat_input=True,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/fcn_unet_s5-d16.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/fcn_unet_s5-d16.py
new file mode 100644
index 0000000000000000000000000000000000000000..a33e7972877f902d0e7d18401ca675e3e4e60a18
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/fcn_unet_s5-d16.py
@@ -0,0 +1,51 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained=None,
+ backbone=dict(
+ type='UNet',
+ in_channels=3,
+ base_channels=64,
+ num_stages=5,
+ strides=(1, 1, 1, 1, 1),
+ enc_num_convs=(2, 2, 2, 2, 2),
+ dec_num_convs=(2, 2, 2, 2),
+ downsamples=(True, True, True, True),
+ enc_dilations=(1, 1, 1, 1, 1),
+ dec_dilations=(1, 1, 1, 1),
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type='ReLU'),
+ upsample_cfg=dict(type='InterpConv'),
+ norm_eval=False),
+ decode_head=dict(
+ type='FCNHead',
+ in_channels=64,
+ in_index=4,
+ channels=64,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=2,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=128,
+ in_index=3,
+ channels=64,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=2,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=256, stride=170))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/fpn_r50.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/fpn_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..86ab327db92e44c14822d65f1c9277cb007f17c1
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/fpn_r50.py
@@ -0,0 +1,36 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 1, 1),
+ strides=(1, 2, 2, 2),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ neck=dict(
+ type='FPN',
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ num_outs=4),
+ decode_head=dict(
+ type='FPNHead',
+ in_channels=[256, 256, 256, 256],
+ in_index=[0, 1, 2, 3],
+ feature_strides=[4, 8, 16, 32],
+ channels=128,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/fpn_uniformer.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/fpn_uniformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8aae98c5991055bfcc08e82ccdc09f8b1d9f8a8d
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/fpn_uniformer.py
@@ -0,0 +1,35 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ backbone=dict(
+ type='UniFormer',
+ embed_dim=[64, 128, 320, 512],
+ layers=[3, 4, 8, 3],
+ head_dim=64,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1),
+ neck=dict(
+ type='FPN',
+ in_channels=[64, 128, 320, 512],
+ out_channels=256,
+ num_outs=4),
+ decode_head=dict(
+ type='FPNHead',
+ in_channels=[256, 256, 256, 256],
+ in_index=[0, 1, 2, 3],
+ feature_strides=[4, 8, 16, 32],
+ channels=128,
+ dropout_ratio=0.1,
+ num_classes=150,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole')
+)
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/gcnet_r50-d8.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/gcnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d2ad69f5c22adfe79d5fdabf920217628987166
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/gcnet_r50-d8.py
@@ -0,0 +1,46 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='GCHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ ratio=1 / 4.,
+ pooling_type='att',
+ fusion_types=('channel_add', ),
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/lraspp_m-v3-d8.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/lraspp_m-v3-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..93258242a90695cc94a7c6bd41562d6a75988771
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/lraspp_m-v3-d8.py
@@ -0,0 +1,25 @@
+# model settings
+norm_cfg = dict(type='SyncBN', eps=0.001, requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ backbone=dict(
+ type='MobileNetV3',
+ arch='large',
+ out_indices=(1, 3, 16),
+ norm_cfg=norm_cfg),
+ decode_head=dict(
+ type='LRASPPHead',
+ in_channels=(16, 24, 960),
+ in_index=(0, 1, 2),
+ channels=128,
+ input_transform='multiple_select',
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type='ReLU'),
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/nonlocal_r50-d8.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/nonlocal_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..5674a39854cafd1f2e363bac99c58ccae62f24da
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/nonlocal_r50-d8.py
@@ -0,0 +1,46 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='NLHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ dropout_ratio=0.1,
+ reduction=2,
+ use_scale=True,
+ mode='embedded_gaussian',
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/ocrnet_hr18.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/ocrnet_hr18.py
new file mode 100644
index 0000000000000000000000000000000000000000..c60f62a7cdf3f5c5096a7a7e725e8268fddcb057
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/ocrnet_hr18.py
@@ -0,0 +1,68 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='CascadeEncoderDecoder',
+ num_stages=2,
+ pretrained='open-mmlab://msra/hrnetv2_w18',
+ backbone=dict(
+ type='HRNet',
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ extra=dict(
+ stage1=dict(
+ num_modules=1,
+ num_branches=1,
+ block='BOTTLENECK',
+ num_blocks=(4, ),
+ num_channels=(64, )),
+ stage2=dict(
+ num_modules=1,
+ num_branches=2,
+ block='BASIC',
+ num_blocks=(4, 4),
+ num_channels=(18, 36)),
+ stage3=dict(
+ num_modules=4,
+ num_branches=3,
+ block='BASIC',
+ num_blocks=(4, 4, 4),
+ num_channels=(18, 36, 72)),
+ stage4=dict(
+ num_modules=3,
+ num_branches=4,
+ block='BASIC',
+ num_blocks=(4, 4, 4, 4),
+ num_channels=(18, 36, 72, 144)))),
+ decode_head=[
+ dict(
+ type='FCNHead',
+ in_channels=[18, 36, 72, 144],
+ channels=sum([18, 36, 72, 144]),
+ in_index=(0, 1, 2, 3),
+ input_transform='resize_concat',
+ kernel_size=1,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=-1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ dict(
+ type='OCRHead',
+ in_channels=[18, 36, 72, 144],
+ in_index=(0, 1, 2, 3),
+ input_transform='resize_concat',
+ channels=512,
+ ocr_channels=256,
+ dropout_ratio=-1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ ],
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/ocrnet_r50-d8.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/ocrnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..615aa3ff703942b6c22b2d6e9642504dd3e41ebd
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/ocrnet_r50-d8.py
@@ -0,0 +1,47 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='CascadeEncoderDecoder',
+ num_stages=2,
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=[
+ dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ dict(
+ type='OCRHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ ocr_channels=256,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
+ ],
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/pointrend_r50.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/pointrend_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d323dbf9466d41e0800aa57ef84045f3d874bdf
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/pointrend_r50.py
@@ -0,0 +1,56 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='CascadeEncoderDecoder',
+ num_stages=2,
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 1, 1),
+ strides=(1, 2, 2, 2),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ neck=dict(
+ type='FPN',
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ num_outs=4),
+ decode_head=[
+ dict(
+ type='FPNHead',
+ in_channels=[256, 256, 256, 256],
+ in_index=[0, 1, 2, 3],
+ feature_strides=[4, 8, 16, 32],
+ channels=128,
+ dropout_ratio=-1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ dict(
+ type='PointHead',
+ in_channels=[256],
+ in_index=[0],
+ channels=256,
+ num_fcs=3,
+ coarse_pred_each_layer=True,
+ dropout_ratio=-1,
+ num_classes=19,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
+ ],
+ # model training and testing settings
+ train_cfg=dict(
+ num_points=2048, oversample_ratio=3, importance_sample_ratio=0.75),
+ test_cfg=dict(
+ mode='whole',
+ subdivision_steps=2,
+ subdivision_num_points=8196,
+ scale_factor=2))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/psanet_r50-d8.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/psanet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..689513fa9d2a40f14bf0ae4ae61f38f0dcc1b3da
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/psanet_r50-d8.py
@@ -0,0 +1,49 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='PSAHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ mask_size=(97, 97),
+ psa_type='bi-direction',
+ compact=False,
+ shrink_factor=2,
+ normalization_factor=1.0,
+ psa_softmax=True,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/pspnet_r50-d8.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/pspnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..f451e08ad2eb0732dcb806b1851eb978d4acf136
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/pspnet_r50-d8.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 2, 4),
+ strides=(1, 2, 1, 1),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='PSPHead',
+ in_channels=2048,
+ in_index=3,
+ channels=512,
+ pool_scales=(1, 2, 3, 6),
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/pspnet_unet_s5-d16.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/pspnet_unet_s5-d16.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcff9ec4f41fad158344ecd77313dc14564f3682
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/pspnet_unet_s5-d16.py
@@ -0,0 +1,50 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained=None,
+ backbone=dict(
+ type='UNet',
+ in_channels=3,
+ base_channels=64,
+ num_stages=5,
+ strides=(1, 1, 1, 1, 1),
+ enc_num_convs=(2, 2, 2, 2, 2),
+ dec_num_convs=(2, 2, 2, 2),
+ downsamples=(True, True, True, True),
+ enc_dilations=(1, 1, 1, 1, 1),
+ dec_dilations=(1, 1, 1, 1),
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type='ReLU'),
+ upsample_cfg=dict(type='InterpConv'),
+ norm_eval=False),
+ decode_head=dict(
+ type='PSPHead',
+ in_channels=64,
+ in_index=4,
+ channels=16,
+ pool_scales=(1, 2, 3, 6),
+ dropout_ratio=0.1,
+ num_classes=2,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=128,
+ in_index=3,
+ channels=64,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=2,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', crop_size=256, stride=170))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/upernet_r50.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/upernet_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..10974962fdd7136031fd06de1700f497d355ceaa
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/upernet_r50.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='ResNetV1c',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ dilations=(1, 1, 1, 1),
+ strides=(1, 2, 2, 2),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True),
+ decode_head=dict(
+ type='UPerHead',
+ in_channels=[256, 512, 1024, 2048],
+ in_index=[0, 1, 2, 3],
+ pool_scales=(1, 2, 3, 6),
+ channels=512,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=1024,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/models/upernet_uniformer.py b/src/custom_controlnet_aux/uniformer/configs/_base_/models/upernet_uniformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..41aa4db809dc6e2c508e98051f61807d07477903
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/models/upernet_uniformer.py
@@ -0,0 +1,43 @@
+# model settings
+norm_cfg = dict(type='BN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained=None,
+ backbone=dict(
+ type='UniFormer',
+ embed_dim=[64, 128, 320, 512],
+ layers=[3, 4, 8, 3],
+ head_dim=64,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1),
+ decode_head=dict(
+ type='UPerHead',
+ in_channels=[64, 128, 320, 512],
+ in_index=[0, 1, 2, 3],
+ pool_scales=(1, 2, 3, 6),
+ channels=512,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=320,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/schedules/schedule_160k.py b/src/custom_controlnet_aux/uniformer/configs/_base_/schedules/schedule_160k.py
new file mode 100644
index 0000000000000000000000000000000000000000..52603890b10f25faf8eec9f9e5a4468fae09b811
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/schedules/schedule_160k.py
@@ -0,0 +1,9 @@
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+optimizer_config = dict()
+# learning policy
+lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
+# runtime settings
+runner = dict(type='IterBasedRunner', max_iters=160000)
+checkpoint_config = dict(by_epoch=False, interval=16000)
+evaluation = dict(interval=16000, metric='mIoU')
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/schedules/schedule_20k.py b/src/custom_controlnet_aux/uniformer/configs/_base_/schedules/schedule_20k.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf780a1b6f6521833c6a5859675147824efa599d
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/schedules/schedule_20k.py
@@ -0,0 +1,9 @@
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+optimizer_config = dict()
+# learning policy
+lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
+# runtime settings
+runner = dict(type='IterBasedRunner', max_iters=20000)
+checkpoint_config = dict(by_epoch=False, interval=2000)
+evaluation = dict(interval=2000, metric='mIoU')
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/schedules/schedule_40k.py b/src/custom_controlnet_aux/uniformer/configs/_base_/schedules/schedule_40k.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdbf841abcb26eed87bf76ab816aff4bae0630ee
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/schedules/schedule_40k.py
@@ -0,0 +1,9 @@
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+optimizer_config = dict()
+# learning policy
+lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
+# runtime settings
+runner = dict(type='IterBasedRunner', max_iters=40000)
+checkpoint_config = dict(by_epoch=False, interval=4000)
+evaluation = dict(interval=4000, metric='mIoU')
diff --git a/src/custom_controlnet_aux/uniformer/configs/_base_/schedules/schedule_80k.py b/src/custom_controlnet_aux/uniformer/configs/_base_/schedules/schedule_80k.py
new file mode 100644
index 0000000000000000000000000000000000000000..c190cee6bdc7922b688ea75dc8f152fa15c24617
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/configs/_base_/schedules/schedule_80k.py
@@ -0,0 +1,9 @@
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+optimizer_config = dict()
+# learning policy
+lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
+# runtime settings
+runner = dict(type='IterBasedRunner', max_iters=80000)
+checkpoint_config = dict(by_epoch=False, interval=8000)
+evaluation = dict(interval=8000, metric='mIoU')
diff --git a/src/custom_controlnet_aux/uniformer/inference.py b/src/custom_controlnet_aux/uniformer/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..7efc93e16f51e70d80340f76d74c6f3db5a26443
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/inference.py
@@ -0,0 +1,137 @@
+
+import torch
+
+import custom_mmpkg.custom_mmcv as mmcv
+from custom_mmpkg.custom_mmcv.parallel import collate, scatter
+from custom_mmpkg.custom_mmcv.runner import load_checkpoint
+from custom_mmpkg.custom_mmseg.datasets.pipelines import Compose
+from custom_mmpkg.custom_mmseg.models import build_segmentor
+
+def init_segmentor(config, checkpoint=None, device='cuda:0'):
+ """Initialize a segmentor from config file.
+
+ Args:
+ config (str or :obj:`mmcv.Config`): Config file path or the config
+ object.
+ checkpoint (str, optional): Checkpoint path. If left as None, the model
+ will not load any weights.
+ device (str, optional) CPU/CUDA device option. Default 'cuda:0'.
+ Use 'cpu' for loading model on CPU.
+ Returns:
+ nn.Module: The constructed segmentor.
+ """
+ if isinstance(config, str):
+ config = mmcv.Config.fromfile(config)
+ elif not isinstance(config, mmcv.Config):
+ raise TypeError('config must be a filename or Config object, '
+ 'but got {}'.format(type(config)))
+ config.model.pretrained = None
+ config.model.train_cfg = None
+ model = build_segmentor(config.model, test_cfg=config.get('test_cfg'))
+ if checkpoint is not None:
+ checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
+ model.CLASSES = checkpoint['meta']['CLASSES']
+ model.PALETTE = checkpoint['meta']['PALETTE']
+ model.cfg = config # save the config in the model for convenience
+ model.to(device)
+ model.eval()
+ return model
+
+
+class LoadImage:
+ """A simple pipeline to load image."""
+
+ def __call__(self, results):
+ """Call function to load images into results.
+
+ Args:
+ results (dict): A result dict contains the file name
+ of the image to be read.
+
+ Returns:
+ dict: ``results`` will be returned containing loaded image.
+ """
+
+ if isinstance(results['img'], str):
+ results['filename'] = results['img']
+ results['ori_filename'] = results['img']
+ else:
+ results['filename'] = None
+ results['ori_filename'] = None
+ img = mmcv.imread(results['img'])
+ results['img'] = img
+ results['img_shape'] = img.shape
+ results['ori_shape'] = img.shape
+ return results
+
+
+def inference_segmentor(model, img):
+ """Inference image(s) with the segmentor.
+
+ Args:
+ model (nn.Module): The loaded segmentor.
+ imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
+ images.
+
+ Returns:
+ (list[Tensor]): The segmentation result.
+ """
+ cfg = model.cfg
+ device = next(model.parameters()).device # model device
+ # build the data pipeline
+ test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
+ test_pipeline = Compose(test_pipeline)
+ # prepare data
+ data = dict(img=img)
+ data = test_pipeline(data)
+ data = collate([data], samples_per_gpu=1)
+ if next(model.parameters()).is_cuda:
+ # scatter to specified GPU
+ data = scatter(data, [device])[0]
+ else:
+ data['img_metas'] = [i.data[0] for i in data['img_metas']]
+
+ data['img'] = [x.to(device) for x in data['img']]
+
+ # forward the model
+ with torch.no_grad():
+ result = model(return_loss=False, rescale=True, **data)
+ return result
+
+
+def show_result_pyplot(model,
+ img,
+ result,
+ palette=None,
+ fig_size=(15, 10),
+ opacity=0.5,
+ title='',
+ block=True):
+ """Visualize the segmentation results on the image.
+
+ Args:
+ model (nn.Module): The loaded segmentor.
+ img (str or np.ndarray): Image filename or loaded image.
+ result (list): The segmentation result.
+ palette (list[list[int]]] | None): The palette of segmentation
+ map. If None is given, random palette will be generated.
+ Default: None
+ fig_size (tuple): Figure size of the pyplot figure.
+ opacity(float): Opacity of painted segmentation map.
+ Default 0.5.
+ Must be in (0, 1] range.
+ title (str): The title of pyplot figure.
+ Default is ''.
+ block (bool): Whether to block the pyplot figure.
+ Default is True.
+ """
+ if hasattr(model, 'module'):
+ model = model.module
+ img = model.show_result(
+ img, result, palette=palette, show=False, opacity=opacity)
+ # plt.figure(figsize=fig_size)
+ # plt.imshow(mmcv.bgr2rgb(img))
+ # plt.title(title)
+ # plt.tight_layout()
+ # plt.show(block=block)
+ return mmcv.bgr2rgb(img)
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/uniformer/mmcv_custom/__init__.py b/src/custom_controlnet_aux/uniformer/mmcv_custom/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b958738b9fd93bfcec239c550df1d9a44b8c536
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/mmcv_custom/__init__.py
@@ -0,0 +1,5 @@
+# -*- coding: utf-8 -*-
+
+from .checkpoint import load_checkpoint
+
+__all__ = ['load_checkpoint']
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/uniformer/mmcv_custom/checkpoint.py b/src/custom_controlnet_aux/uniformer/mmcv_custom/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..8453fedcd47fafbedd40ca7ed485dce2e23434e0
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/mmcv_custom/checkpoint.py
@@ -0,0 +1,500 @@
+# Copyright (c) Open-MMLab. All rights reserved.
+import io
+import os
+import os.path as osp
+import pkgutil
+import time
+import warnings
+from collections import OrderedDict
+from importlib import import_module
+from tempfile import TemporaryDirectory
+
+import torch
+import torchvision
+from torch.optim import Optimizer
+from torch.utils import model_zoo
+from torch.nn import functional as F
+
+import custom_mmpkg.custom_mmcv as mmcv
+from custom_mmpkg.custom_mmcv.fileio import FileClient
+from custom_mmpkg.custom_mmcv.fileio import load as load_file
+from custom_mmpkg.custom_mmcv.parallel import is_module_wrapper
+from custom_mmpkg.custom_mmcv.utils import mkdir_or_exist
+from custom_mmpkg.custom_mmcv.runner import get_dist_info
+
+ENV_MMCV_HOME = 'MMCV_HOME'
+ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+DEFAULT_CACHE_DIR = '~/.cache'
+
+
+def _get_mmcv_home():
+ mmcv_home = os.path.expanduser(
+ os.getenv(
+ ENV_MMCV_HOME,
+ os.path.join(
+ os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))
+
+ mkdir_or_exist(mmcv_home)
+ return mmcv_home
+
+
+def load_state_dict(module, state_dict, strict=False, logger=None):
+ """Load state_dict to a module.
+
+ This method is modified from :meth:`torch.nn.Module.load_state_dict`.
+ Default value for ``strict`` is set to ``False`` and the message for
+ param mismatch will be shown even if strict is False.
+
+ Args:
+ module (Module): Module that receives the state_dict.
+ state_dict (OrderedDict): Weights.
+ strict (bool): whether to strictly enforce that the keys
+ in :attr:`state_dict` match the keys returned by this module's
+ :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
+ logger (:obj:`logging.Logger`, optional): Logger to log the error
+ message. If not specified, print function will be used.
+ """
+ unexpected_keys = []
+ all_missing_keys = []
+ err_msg = []
+
+ metadata = getattr(state_dict, '_metadata', None)
+ state_dict = state_dict.copy()
+ if metadata is not None:
+ state_dict._metadata = metadata
+
+ # use _load_from_state_dict to enable checkpoint version control
+ def load(module, prefix=''):
+ # recursively check parallel module in case that the model has a
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
+ if is_module_wrapper(module):
+ module = module.module
+ local_metadata = {} if metadata is None else metadata.get(
+ prefix[:-1], {})
+ module._load_from_state_dict(state_dict, prefix, local_metadata, True,
+ all_missing_keys, unexpected_keys,
+ err_msg)
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + '.')
+
+ load(module)
+ load = None # break load->load reference cycle
+
+ # ignore "num_batches_tracked" of BN layers
+ missing_keys = [
+ key for key in all_missing_keys if 'num_batches_tracked' not in key
+ ]
+
+ if unexpected_keys:
+ err_msg.append('unexpected key in source '
+ f'state_dict: {", ".join(unexpected_keys)}\n')
+ if missing_keys:
+ err_msg.append(
+ f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
+
+ rank, _ = get_dist_info()
+ if len(err_msg) > 0 and rank == 0:
+ err_msg.insert(
+ 0, 'The model and loaded state dict do not match exactly\n')
+ err_msg = '\n'.join(err_msg)
+ if strict:
+ raise RuntimeError(err_msg)
+ elif logger is not None:
+ logger.warning(err_msg)
+ else:
+ print(err_msg)
+
+
+def load_url_dist(url, model_dir=None):
+ """In distributed setting, this function only download checkpoint at local
+ rank 0."""
+ rank, world_size = get_dist_info()
+ rank = int(os.environ.get('LOCAL_RANK', rank))
+ if rank == 0:
+ checkpoint = model_zoo.load_url(url, model_dir=model_dir)
+ if world_size > 1:
+ torch.distributed.barrier()
+ if rank > 0:
+ checkpoint = model_zoo.load_url(url, model_dir=model_dir)
+ return checkpoint
+
+
+def load_pavimodel_dist(model_path, map_location=None):
+ """In distributed setting, this function only download checkpoint at local
+ rank 0."""
+ try:
+ from pavi import modelcloud
+ except ImportError:
+ raise ImportError(
+ 'Please install pavi to load checkpoint from modelcloud.')
+ rank, world_size = get_dist_info()
+ rank = int(os.environ.get('LOCAL_RANK', rank))
+ if rank == 0:
+ model = modelcloud.get(model_path)
+ with TemporaryDirectory() as tmp_dir:
+ downloaded_file = osp.join(tmp_dir, model.name)
+ model.download(downloaded_file)
+ checkpoint = torch.load(downloaded_file, map_location=map_location)
+ if world_size > 1:
+ torch.distributed.barrier()
+ if rank > 0:
+ model = modelcloud.get(model_path)
+ with TemporaryDirectory() as tmp_dir:
+ downloaded_file = osp.join(tmp_dir, model.name)
+ model.download(downloaded_file)
+ checkpoint = torch.load(
+ downloaded_file, map_location=map_location)
+ return checkpoint
+
+
+def load_fileclient_dist(filename, backend, map_location):
+ """In distributed setting, this function only download checkpoint at local
+ rank 0."""
+ rank, world_size = get_dist_info()
+ rank = int(os.environ.get('LOCAL_RANK', rank))
+ allowed_backends = ['ceph']
+ if backend not in allowed_backends:
+ raise ValueError(f'Load from Backend {backend} is not supported.')
+ if rank == 0:
+ fileclient = FileClient(backend=backend)
+ buffer = io.BytesIO(fileclient.get(filename))
+ checkpoint = torch.load(buffer, map_location=map_location)
+ if world_size > 1:
+ torch.distributed.barrier()
+ if rank > 0:
+ fileclient = FileClient(backend=backend)
+ buffer = io.BytesIO(fileclient.get(filename))
+ checkpoint = torch.load(buffer, map_location=map_location)
+ return checkpoint
+
+
+def get_torchvision_models():
+ model_urls = dict()
+ for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
+ if ispkg:
+ continue
+ _zoo = import_module(f'torchvision.models.{name}')
+ if hasattr(_zoo, 'model_urls'):
+ _urls = getattr(_zoo, 'model_urls')
+ model_urls.update(_urls)
+ return model_urls
+
+
+def get_external_models():
+ mmcv_home = _get_mmcv_home()
+ default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
+ default_urls = load_file(default_json_path)
+ assert isinstance(default_urls, dict)
+ external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
+ if osp.exists(external_json_path):
+ external_urls = load_file(external_json_path)
+ assert isinstance(external_urls, dict)
+ default_urls.update(external_urls)
+
+ return default_urls
+
+
+def get_mmcls_models():
+ mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
+ mmcls_urls = load_file(mmcls_json_path)
+
+ return mmcls_urls
+
+
+def get_deprecated_model_names():
+ deprecate_json_path = osp.join(mmcv.__path__[0],
+ 'model_zoo/deprecated.json')
+ deprecate_urls = load_file(deprecate_json_path)
+ assert isinstance(deprecate_urls, dict)
+
+ return deprecate_urls
+
+
+def _process_mmcls_checkpoint(checkpoint):
+ state_dict = checkpoint['state_dict']
+ new_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ if k.startswith('backbone.'):
+ new_state_dict[k[9:]] = v
+ new_checkpoint = dict(state_dict=new_state_dict)
+
+ return new_checkpoint
+
+
+def _load_checkpoint(filename, map_location=None):
+ """Load checkpoint from somewhere (modelzoo, file, url).
+
+ Args:
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+ details.
+ map_location (str | None): Same as :func:`torch.load`. Default: None.
+
+ Returns:
+ dict | OrderedDict: The loaded checkpoint. It can be either an
+ OrderedDict storing model weights or a dict containing other
+ information, which depends on the checkpoint.
+ """
+ if filename.startswith('modelzoo://'):
+ warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
+ 'use "torchvision://" instead')
+ model_urls = get_torchvision_models()
+ model_name = filename[11:]
+ checkpoint = load_url_dist(model_urls[model_name])
+ elif filename.startswith('torchvision://'):
+ model_urls = get_torchvision_models()
+ model_name = filename[14:]
+ checkpoint = load_url_dist(model_urls[model_name])
+ elif filename.startswith('open-mmlab://'):
+ model_urls = get_external_models()
+ model_name = filename[13:]
+ deprecated_urls = get_deprecated_model_names()
+ if model_name in deprecated_urls:
+ warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
+ f'of open-mmlab://{deprecated_urls[model_name]}')
+ model_name = deprecated_urls[model_name]
+ model_url = model_urls[model_name]
+ # check if is url
+ if model_url.startswith(('http://', 'https://')):
+ checkpoint = load_url_dist(model_url)
+ else:
+ filename = osp.join(_get_mmcv_home(), model_url)
+ if not osp.isfile(filename):
+ raise IOError(f'{filename} is not a checkpoint file')
+ checkpoint = torch.load(filename, map_location=map_location)
+ elif filename.startswith('mmcls://'):
+ model_urls = get_mmcls_models()
+ model_name = filename[8:]
+ checkpoint = load_url_dist(model_urls[model_name])
+ checkpoint = _process_mmcls_checkpoint(checkpoint)
+ elif filename.startswith(('http://', 'https://')):
+ checkpoint = load_url_dist(filename)
+ elif filename.startswith('pavi://'):
+ model_path = filename[7:]
+ checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
+ elif filename.startswith('s3://'):
+ checkpoint = load_fileclient_dist(
+ filename, backend='ceph', map_location=map_location)
+ else:
+ if not osp.isfile(filename):
+ raise IOError(f'{filename} is not a checkpoint file')
+ checkpoint = torch.load(filename, map_location=map_location)
+ return checkpoint
+
+
+def load_checkpoint(model,
+ filename,
+ map_location='cpu',
+ strict=False,
+ logger=None):
+ """Load checkpoint from a file or URI.
+
+ Args:
+ model (Module): Module to load checkpoint.
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+ details.
+ map_location (str): Same as :func:`torch.load`.
+ strict (bool): Whether to allow different params for the model and
+ checkpoint.
+ logger (:mod:`logging.Logger` or None): The logger for error message.
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+ checkpoint = _load_checkpoint(filename, map_location)
+ # OrderedDict is a subclass of dict
+ if not isinstance(checkpoint, dict):
+ raise RuntimeError(
+ f'No state_dict found in checkpoint file {filename}')
+ # get state_dict from checkpoint
+ if 'state_dict' in checkpoint:
+ state_dict = checkpoint['state_dict']
+ elif 'model' in checkpoint:
+ state_dict = checkpoint['model']
+ else:
+ state_dict = checkpoint
+ # strip prefix of state_dict
+ if list(state_dict.keys())[0].startswith('module.'):
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
+
+ # for MoBY, load model of online branch
+ if sorted(list(state_dict.keys()))[0].startswith('encoder'):
+ state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
+
+ # reshape absolute position embedding
+ if state_dict.get('absolute_pos_embed') is not None:
+ absolute_pos_embed = state_dict['absolute_pos_embed']
+ N1, L, C1 = absolute_pos_embed.size()
+ N2, C2, H, W = model.absolute_pos_embed.size()
+ if N1 != N2 or C1 != C2 or L != H*W:
+ logger.warning("Error in loading absolute_pos_embed, pass")
+ else:
+ state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)
+
+ # interpolate position bias table if needed
+ relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
+ for table_key in relative_position_bias_table_keys:
+ table_pretrained = state_dict[table_key]
+ table_current = model.state_dict()[table_key]
+ L1, nH1 = table_pretrained.size()
+ L2, nH2 = table_current.size()
+ if nH1 != nH2:
+ logger.warning(f"Error in loading {table_key}, pass")
+ else:
+ if L1 != L2:
+ S1 = int(L1 ** 0.5)
+ S2 = int(L2 ** 0.5)
+ table_pretrained_resized = F.interpolate(
+ table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
+ size=(S2, S2), mode='bicubic')
+ state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)
+
+ # load state_dict
+ load_state_dict(model, state_dict, strict, logger)
+ return checkpoint
+
+
+def weights_to_cpu(state_dict):
+ """Copy a model state_dict to cpu.
+
+ Args:
+ state_dict (OrderedDict): Model weights on GPU.
+
+ Returns:
+ OrderedDict: Model weights on GPU.
+ """
+ state_dict_cpu = OrderedDict()
+ for key, val in state_dict.items():
+ state_dict_cpu[key] = val.cpu()
+ return state_dict_cpu
+
+
+def _save_to_state_dict(module, destination, prefix, keep_vars):
+ """Saves module state to `destination` dictionary.
+
+ This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
+
+ Args:
+ module (nn.Module): The module to generate state_dict.
+ destination (dict): A dict where state will be stored.
+ prefix (str): The prefix for parameters and buffers used in this
+ module.
+ """
+ for name, param in module._parameters.items():
+ if param is not None:
+ destination[prefix + name] = param if keep_vars else param.detach()
+ for name, buf in module._buffers.items():
+ # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
+ if buf is not None:
+ destination[prefix + name] = buf if keep_vars else buf.detach()
+
+
+def get_state_dict(module, destination=None, prefix='', keep_vars=False):
+ """Returns a dictionary containing a whole state of the module.
+
+ Both parameters and persistent buffers (e.g. running averages) are
+ included. Keys are corresponding parameter and buffer names.
+
+ This method is modified from :meth:`torch.nn.Module.state_dict` to
+ recursively check parallel module in case that the model has a complicated
+ structure, e.g., nn.Module(nn.Module(DDP)).
+
+ Args:
+ module (nn.Module): The module to generate state_dict.
+ destination (OrderedDict): Returned dict for the state of the
+ module.
+ prefix (str): Prefix of the key.
+ keep_vars (bool): Whether to keep the variable property of the
+ parameters. Default: False.
+
+ Returns:
+ dict: A dictionary containing a whole state of the module.
+ """
+ # recursively check parallel module in case that the model has a
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
+ if is_module_wrapper(module):
+ module = module.module
+
+ # below is the same as torch.nn.Module.state_dict()
+ if destination is None:
+ destination = OrderedDict()
+ destination._metadata = OrderedDict()
+ destination._metadata[prefix[:-1]] = local_metadata = dict(
+ version=module._version)
+ _save_to_state_dict(module, destination, prefix, keep_vars)
+ for name, child in module._modules.items():
+ if child is not None:
+ get_state_dict(
+ child, destination, prefix + name + '.', keep_vars=keep_vars)
+ for hook in module._state_dict_hooks.values():
+ hook_result = hook(module, destination, prefix, local_metadata)
+ if hook_result is not None:
+ destination = hook_result
+ return destination
+
+
+def save_checkpoint(model, filename, optimizer=None, meta=None):
+ """Save checkpoint to file.
+
+ The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
+ ``optimizer``. By default ``meta`` will contain version and time info.
+
+ Args:
+ model (Module): Module whose params are to be saved.
+ filename (str): Checkpoint filename.
+ optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
+ meta (dict, optional): Metadata to be saved in checkpoint.
+ """
+ if meta is None:
+ meta = {}
+ elif not isinstance(meta, dict):
+ raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
+ meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
+
+ if is_module_wrapper(model):
+ model = model.module
+
+ if hasattr(model, 'CLASSES') and model.CLASSES is not None:
+ # save class name to the meta
+ meta.update(CLASSES=model.CLASSES)
+
+ checkpoint = {
+ 'meta': meta,
+ 'state_dict': weights_to_cpu(get_state_dict(model))
+ }
+ # save optimizer state dict in the checkpoint
+ if isinstance(optimizer, Optimizer):
+ checkpoint['optimizer'] = optimizer.state_dict()
+ elif isinstance(optimizer, dict):
+ checkpoint['optimizer'] = {}
+ for name, optim in optimizer.items():
+ checkpoint['optimizer'][name] = optim.state_dict()
+
+ if filename.startswith('pavi://'):
+ try:
+ from pavi import modelcloud
+ from pavi.exception import NodeNotFoundError
+ except ImportError:
+ raise ImportError(
+ 'Please install pavi to load checkpoint from modelcloud.')
+ model_path = filename[7:]
+ root = modelcloud.Folder()
+ model_dir, model_name = osp.split(model_path)
+ try:
+ model = modelcloud.get(model_dir)
+ except NodeNotFoundError:
+ model = root.create_training_model(model_dir)
+ with TemporaryDirectory() as tmp_dir:
+ checkpoint_file = osp.join(tmp_dir, model_name)
+ with open(checkpoint_file, 'wb') as f:
+ torch.save(checkpoint, f)
+ f.flush()
+ model.create_file(checkpoint_file, name=model_name)
+ else:
+ mmcv.mkdir_or_exist(osp.dirname(filename))
+ # immediately flush buffer
+ with open(filename, 'wb') as f:
+ torch.save(checkpoint, f)
+ f.flush()
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/uniformer/uniformer.py b/src/custom_controlnet_aux/uniformer/uniformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..33dc1aa318e332187ad925b1f74c3b854db145d5
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/uniformer.py
@@ -0,0 +1,421 @@
+# --------------------------------------------------------
+# UniFormer
+# Copyright (c) 2022 SenseTime X-Lab
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Kunchang Li
+# --------------------------------------------------------
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+
+from functools import partial
+from collections import OrderedDict
+from timm.layers import DropPath, to_2tuple, trunc_normal_
+from custom_mmpkg.custom_mmseg.utils import get_root_logger
+from custom_mmpkg.custom_mmseg.models.builder import BACKBONES
+
+from .mmcv_custom import load_checkpoint
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class CMlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
+ self.act = act_layer()
+ self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class CBlock(nn.Module):
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
+ self.norm1 = nn.BatchNorm2d(dim)
+ self.conv1 = nn.Conv2d(dim, dim, 1)
+ self.conv2 = nn.Conv2d(dim, dim, 1)
+ self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = nn.BatchNorm2d(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x):
+ x = x + self.pos_embed(x)
+ x = x + self.drop_path(self.conv2(self.attn(self.conv1(self.norm1(x)))))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class SABlock(nn.Module):
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ attn_drop=attn_drop, proj_drop=drop)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x):
+ x = x + self.pos_embed(x)
+ B, N, H, W = x.shape
+ x = x.flatten(2).transpose(1, 2)
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ x = x.transpose(1, 2).reshape(B, N, H, W)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+class SABlock_Windows(nn.Module):
+ def __init__(self, dim, num_heads, window_size=14, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.window_size=window_size
+ self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ attn_drop=attn_drop, proj_drop=drop)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x):
+ x = x + self.pos_embed(x)
+ x = x.permute(0, 2, 3, 1)
+ B, H, W, C = x.shape
+ shortcut = x
+ x = self.norm1(x)
+
+ pad_l = pad_t = 0
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+ _, Hp, Wp, _ = x.shape
+
+ x_windows = window_partition(x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ attn_windows = self.attn(x_windows) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
+
+ # reverse cyclic shift
+ if pad_r > 0 or pad_b > 0:
+ x = x[:, :H, :W, :].contiguous()
+
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ x = x.permute(0, 3, 1, 2).reshape(B, C, H, W)
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+ self.norm = nn.LayerNorm(embed_dim)
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, x):
+ B, _, H, W = x.shape
+ x = self.proj(x)
+ B, _, H, W = x.shape
+ x = x.flatten(2).transpose(1, 2)
+ x = self.norm(x)
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+ return x
+
+
+@BACKBONES.register_module()
+class UniFormer(nn.Module):
+ """ Vision Transformer
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
+ https://arxiv.org/abs/2010.11929
+ """
+ def __init__(self, layers=[3, 4, 8, 3], img_size=224, in_chans=3, num_classes=80, embed_dim=[64, 128, 320, 512],
+ head_dim=64, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ pretrained_path=None, use_checkpoint=False, checkpoint_num=[0, 0, 0, 0],
+ windows=False, hybrid=False, window_size=14):
+ """
+ Args:
+ layer (list): number of block in each layer
+ img_size (int, tuple): input image size
+ in_chans (int): number of input channels
+ num_classes (int): number of classes for classification head
+ embed_dim (int): embedding dimension
+ head_dim (int): dimension of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
+ drop_rate (float): dropout rate
+ attn_drop_rate (float): attention dropout rate
+ drop_path_rate (float): stochastic depth rate
+ norm_layer (nn.Module): normalization layer
+ pretrained_path (str): path of pretrained model
+ use_checkpoint (bool): whether use checkpoint
+ checkpoint_num (list): index for using checkpoint in every stage
+ windows (bool): whether use window MHRA
+ hybrid (bool): whether use hybrid MHRA
+ window_size (int): size of window (>14)
+ """
+ super().__init__()
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.checkpoint_num = checkpoint_num
+ self.windows = windows
+ print(f'Use Checkpoint: {self.use_checkpoint}')
+ print(f'Checkpoint Number: {self.checkpoint_num}')
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+
+ self.patch_embed1 = PatchEmbed(
+ img_size=img_size, patch_size=4, in_chans=in_chans, embed_dim=embed_dim[0])
+ self.patch_embed2 = PatchEmbed(
+ img_size=img_size // 4, patch_size=2, in_chans=embed_dim[0], embed_dim=embed_dim[1])
+ self.patch_embed3 = PatchEmbed(
+ img_size=img_size // 8, patch_size=2, in_chans=embed_dim[1], embed_dim=embed_dim[2])
+ self.patch_embed4 = PatchEmbed(
+ img_size=img_size // 16, patch_size=2, in_chans=embed_dim[2], embed_dim=embed_dim[3])
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(layers))] # stochastic depth decay rule
+ num_heads = [dim // head_dim for dim in embed_dim]
+ self.blocks1 = nn.ModuleList([
+ CBlock(
+ dim=embed_dim[0], num_heads=num_heads[0], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
+ for i in range(layers[0])])
+ self.norm1=norm_layer(embed_dim[0])
+ self.blocks2 = nn.ModuleList([
+ CBlock(
+ dim=embed_dim[1], num_heads=num_heads[1], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]], norm_layer=norm_layer)
+ for i in range(layers[1])])
+ self.norm2 = norm_layer(embed_dim[1])
+ if self.windows:
+ print('Use local window for all blocks in stage3')
+ self.blocks3 = nn.ModuleList([
+ SABlock_Windows(
+ dim=embed_dim[2], num_heads=num_heads[2], window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]], norm_layer=norm_layer)
+ for i in range(layers[2])])
+ elif hybrid:
+ print('Use hybrid window for blocks in stage3')
+ block3 = []
+ for i in range(layers[2]):
+ if (i + 1) % 4 == 0:
+ block3.append(SABlock(
+ dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]], norm_layer=norm_layer))
+ else:
+ block3.append(SABlock_Windows(
+ dim=embed_dim[2], num_heads=num_heads[2], window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]], norm_layer=norm_layer))
+ self.blocks3 = nn.ModuleList(block3)
+ else:
+ print('Use global window for all blocks in stage3')
+ self.blocks3 = nn.ModuleList([
+ SABlock(
+ dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]], norm_layer=norm_layer)
+ for i in range(layers[2])])
+ self.norm3 = norm_layer(embed_dim[2])
+ self.blocks4 = nn.ModuleList([
+ SABlock(
+ dim=embed_dim[3], num_heads=num_heads[3], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]+layers[2]], norm_layer=norm_layer)
+ for i in range(layers[3])])
+ self.norm4 = norm_layer(embed_dim[3])
+
+ # Representation layer
+ if representation_size:
+ self.num_features = representation_size
+ self.pre_logits = nn.Sequential(OrderedDict([
+ ('fc', nn.Linear(embed_dim, representation_size)),
+ ('act', nn.Tanh())
+ ]))
+ else:
+ self.pre_logits = nn.Identity()
+
+ self.apply(self._init_weights)
+ self.init_weights(pretrained=pretrained_path)
+
+ def init_weights(self, pretrained):
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
+ print(f'Load pretrained model from {pretrained}')
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token'}
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+ out = []
+ x = self.patch_embed1(x)
+ x = self.pos_drop(x)
+ for i, blk in enumerate(self.blocks1):
+ if self.use_checkpoint and i < self.checkpoint_num[0]:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ x_out = self.norm1(x.permute(0, 2, 3, 1))
+ out.append(x_out.permute(0, 3, 1, 2).contiguous())
+ x = self.patch_embed2(x)
+ for i, blk in enumerate(self.blocks2):
+ if self.use_checkpoint and i < self.checkpoint_num[1]:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ x_out = self.norm2(x.permute(0, 2, 3, 1))
+ out.append(x_out.permute(0, 3, 1, 2).contiguous())
+ x = self.patch_embed3(x)
+ for i, blk in enumerate(self.blocks3):
+ if self.use_checkpoint and i < self.checkpoint_num[2]:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ x_out = self.norm3(x.permute(0, 2, 3, 1))
+ out.append(x_out.permute(0, 3, 1, 2).contiguous())
+ x = self.patch_embed4(x)
+ for i, blk in enumerate(self.blocks4):
+ if self.use_checkpoint and i < self.checkpoint_num[3]:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ x_out = self.norm4(x.permute(0, 2, 3, 1))
+ out.append(x_out.permute(0, 3, 1, 2).contiguous())
+ return tuple(out)
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ return x
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/uniformer/upernet_global_small.py b/src/custom_controlnet_aux/uniformer/upernet_global_small.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1084a1148461ad65e6f206463cf59148f822f0b
--- /dev/null
+++ b/src/custom_controlnet_aux/uniformer/upernet_global_small.py
@@ -0,0 +1,44 @@
+_base_ = [
+ 'configs/_base_/models/upernet_uniformer.py',
+ 'configs/_base_/datasets/ade20k.py',
+ 'configs/_base_/default_runtime.py',
+ 'configs/_base_/schedules/schedule_160k.py'
+]
+
+custom_imports = dict(
+ imports=['custom_controlnet_aux.uniformer.uniformer'],
+ allow_failed_imports=False
+)
+
+model = dict(
+ backbone=dict(
+ type='UniFormer',
+ embed_dim=[64, 128, 320, 512],
+ layers=[3, 4, 8, 3],
+ head_dim=64,
+ drop_path_rate=0.25,
+ windows=False,
+ hybrid=False
+ ),
+ decode_head=dict(
+ in_channels=[64, 128, 320, 512],
+ num_classes=150
+ ),
+ auxiliary_head=dict(
+ in_channels=320,
+ num_classes=150
+ ))
+
+# AdamW optimizer, no weight decay for position embedding & layer norm in backbone
+optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
+ paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'norm': dict(decay_mult=0.)}))
+
+lr_config = dict(_delete_=True, policy='poly',
+ warmup='linear',
+ warmup_iters=1500,
+ warmup_ratio=1e-6,
+ power=1.0, min_lr=0.0, by_epoch=False)
+
+data=dict(samples_per_gpu=2)
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/unimatch/__init__.py b/src/custom_controlnet_aux/unimatch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..add304dc30aaaa85ca722190b27171c4191f5d72
--- /dev/null
+++ b/src/custom_controlnet_aux/unimatch/__init__.py
@@ -0,0 +1,195 @@
+import os
+import warnings
+
+import cv2
+import numpy as np
+import torch
+from einops import rearrange
+from PIL import Image
+
+from custom_controlnet_aux.util import resize_image_with_pad,common_input_validate, custom_hf_download, UNIMATCH_MODEL_NAME
+from .utils.flow_viz import save_vis_flow_tofile, flow_to_image
+from .unimatch.unimatch import UniMatch
+import torch.nn.functional as F
+from argparse import Namespace
+
+def inference_flow(model,
+ image1, #np array of HWC
+ image2,
+ padding_factor=8,
+ inference_size=None,
+ attn_type='swin',
+ attn_splits_list=None,
+ corr_radius_list=None,
+ prop_radius_list=None,
+ num_reg_refine=1,
+ pred_bidir_flow=False,
+ pred_bwd_flow=False,
+ fwd_bwd_consistency_check=False,
+ device="cpu",
+ **kwargs
+ ):
+ fixed_inference_size = inference_size
+ transpose_img = False
+ image1 = torch.from_numpy(image1).permute(2, 0, 1).float().unsqueeze(0).to(device)
+ image2 = torch.from_numpy(image2).permute(2, 0, 1).float().unsqueeze(0).to(device)
+
+ # the model is trained with size: width > height
+ if image1.size(-2) > image1.size(-1):
+ image1 = torch.transpose(image1, -2, -1)
+ image2 = torch.transpose(image2, -2, -1)
+ transpose_img = True
+
+ nearest_size = [int(np.ceil(image1.size(-2) / padding_factor)) * padding_factor,
+ int(np.ceil(image1.size(-1) / padding_factor)) * padding_factor]
+ # resize to nearest size or specified size
+ inference_size = nearest_size if fixed_inference_size is None else fixed_inference_size
+ assert isinstance(inference_size, list) or isinstance(inference_size, tuple)
+ ori_size = image1.shape[-2:]
+
+ # resize before inference
+ if inference_size[0] != ori_size[0] or inference_size[1] != ori_size[1]:
+ image1 = F.interpolate(image1, size=inference_size, mode='bilinear',
+ align_corners=True)
+ image2 = F.interpolate(image2, size=inference_size, mode='bilinear',
+ align_corners=True)
+ if pred_bwd_flow:
+ image1, image2 = image2, image1
+
+ results_dict = model(image1, image2,
+ attn_type=attn_type,
+ attn_splits_list=attn_splits_list,
+ corr_radius_list=corr_radius_list,
+ prop_radius_list=prop_radius_list,
+ num_reg_refine=num_reg_refine,
+ task='flow',
+ pred_bidir_flow=pred_bidir_flow,
+ )
+ flow_pr = results_dict['flow_preds'][-1] # [B, 2, H, W]
+
+ # resize back
+ if inference_size[0] != ori_size[0] or inference_size[1] != ori_size[1]:
+ flow_pr = F.interpolate(flow_pr, size=ori_size, mode='bilinear',
+ align_corners=True)
+ flow_pr[:, 0] = flow_pr[:, 0] * ori_size[-1] / inference_size[-1]
+ flow_pr[:, 1] = flow_pr[:, 1] * ori_size[-2] / inference_size[-2]
+
+ if transpose_img:
+ flow_pr = torch.transpose(flow_pr, -2, -1)
+
+ flow = flow_pr[0].permute(1, 2, 0).cpu().numpy() # [H, W, 2]
+
+ vis_image = flow_to_image(flow)
+
+ # also predict backward flow
+ if pred_bidir_flow:
+ assert flow_pr.size(0) == 2 # [2, H, W, 2]
+ flow_bwd = flow_pr[1].permute(1, 2, 0).cpu().numpy() # [H, W, 2]
+ vis_image = flow_to_image(flow_bwd)
+ flow = flow_bwd
+ return flow, vis_image
+
+MODEL_CONFIGS = {
+ "gmflow-scale1": Namespace(
+ num_scales=1,
+ upsample_factor=8,
+
+ attn_type="swin",
+ feature_channels=128,
+ num_head=1,
+ ffn_dim_expansion=4,
+ num_transformer_layers=6,
+
+ attn_splits_list=[2],
+ corr_radius_list=[-1],
+ prop_radius_list=[-1],
+
+ reg_refine=False,
+ num_reg_refine=1
+ ),
+ "gmflow-scale2": Namespace(
+ num_scales=2,
+ upsample_factor=4,
+ padding_factor=32,
+
+ attn_type="swin",
+ feature_channels=128,
+ num_head=1,
+ ffn_dim_expansion=4,
+ num_transformer_layers=6,
+
+ attn_splits_list=[2, 8],
+ corr_radius_list=[-1, 4],
+ prop_radius_list=[-1, 1],
+
+ reg_refine=False,
+ num_reg_refine=1
+ ),
+ "gmflow-scale2-regrefine6": Namespace(
+ num_scales=2,
+ upsample_factor=4,
+ padding_factor=32,
+
+ attn_type="swin",
+ feature_channels=128,
+ num_head=1,
+ ffn_dim_expansion=4,
+ num_transformer_layers=6,
+
+ attn_splits_list=[2, 8],
+ corr_radius_list=[-1, 4],
+ prop_radius_list=[-1, 1],
+
+ reg_refine=True,
+ num_reg_refine=6
+ )
+}
+
+class UnimatchDetector:
+ def __init__(self, unimatch, config_args):
+ self.unimatch = unimatch
+ self.config_args = config_args
+ self.device = "cpu"
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path=UNIMATCH_MODEL_NAME, filename="gmflow-scale2-regrefine6-mixdata.pth"):
+ model_path = custom_hf_download(pretrained_model_or_path, filename)
+ config_args = None
+ for key in list(MODEL_CONFIGS.keys())[::-1]:
+ if key in filename:
+ config_args = MODEL_CONFIGS[key]
+ break
+ assert config_args, f"Couldn't find hardcoded Unimatch config for {filename}"
+
+ model = UniMatch(feature_channels=config_args.feature_channels,
+ num_scales=config_args.num_scales,
+ upsample_factor=config_args.upsample_factor,
+ num_head=config_args.num_head,
+ ffn_dim_expansion=config_args.ffn_dim_expansion,
+ num_transformer_layers=config_args.num_transformer_layers,
+ reg_refine=config_args.reg_refine,
+ task='flow')
+
+ sd = torch.load(model_path, map_location="cpu")
+ model.load_state_dict(sd['model'])
+ return cls(model, config_args)
+
+ def to(self, device):
+ self.unimatch.to(device)
+ self.device = device
+ return self
+
+ def __call__(self, image1, image2, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", pred_bwd_flow=False, pred_bidir_flow=False, **kwargs):
+ assert image1.shape == image2.shape, f"[Unimatch] image1 and image2 must have the same size, got {image1.shape} and {image2.shape}"
+
+ image1, output_type = common_input_validate(image1, output_type, **kwargs)
+ #image1, remove_pad = resize_image_with_pad(image1, detect_resolution, upscale_method)
+ image2, output_type = common_input_validate(image2, output_type, **kwargs)
+ #image2, remove_pad = resize_image_with_pad(image2, detect_resolution, upscale_method)
+ with torch.no_grad():
+ flow, vis_image = inference_flow(self.unimatch, image1, image2, device=self.device, pred_bwd_flow=pred_bwd_flow, pred_bidir_flow=pred_bidir_flow, **vars(self.config_args))
+
+ if output_type == "pil":
+ vis_image = Image.fromarray(vis_image)
+
+ return flow, vis_image
diff --git a/src/custom_controlnet_aux/unimatch/unimatch/__init__.py b/src/custom_controlnet_aux/unimatch/unimatch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/custom_controlnet_aux/unimatch/unimatch/attention.py b/src/custom_controlnet_aux/unimatch/unimatch/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..92a3c878afe541753022ba85c43b5b2e86e4d254
--- /dev/null
+++ b/src/custom_controlnet_aux/unimatch/unimatch/attention.py
@@ -0,0 +1,253 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .utils import split_feature, merge_splits, split_feature_1d, merge_splits_1d
+
+
+def single_head_full_attention(q, k, v):
+ # q, k, v: [B, L, C]
+ assert q.dim() == k.dim() == v.dim() == 3
+
+ scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** .5) # [B, L, L]
+ attn = torch.softmax(scores, dim=2) # [B, L, L]
+ out = torch.matmul(attn, v) # [B, L, C]
+
+ return out
+
+
+def single_head_full_attention_1d(q, k, v,
+ h=None,
+ w=None,
+ ):
+ # q, k, v: [B, L, C]
+
+ assert h is not None and w is not None
+ assert q.size(1) == h * w
+
+ b, _, c = q.size()
+
+ q = q.view(b, h, w, c) # [B, H, W, C]
+ k = k.view(b, h, w, c)
+ v = v.view(b, h, w, c)
+
+ scale_factor = c ** 0.5
+
+ scores = torch.matmul(q, k.permute(0, 1, 3, 2)) / scale_factor # [B, H, W, W]
+
+ attn = torch.softmax(scores, dim=-1)
+
+ out = torch.matmul(attn, v).view(b, -1, c) # [B, H*W, C]
+
+ return out
+
+
+def single_head_split_window_attention(q, k, v,
+ num_splits=1,
+ with_shift=False,
+ h=None,
+ w=None,
+ attn_mask=None,
+ ):
+ # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
+ # q, k, v: [B, L, C]
+ assert q.dim() == k.dim() == v.dim() == 3
+
+ assert h is not None and w is not None
+ assert q.size(1) == h * w
+
+ b, _, c = q.size()
+
+ b_new = b * num_splits * num_splits
+
+ window_size_h = h // num_splits
+ window_size_w = w // num_splits
+
+ q = q.view(b, h, w, c) # [B, H, W, C]
+ k = k.view(b, h, w, c)
+ v = v.view(b, h, w, c)
+
+ scale_factor = c ** 0.5
+
+ if with_shift:
+ assert attn_mask is not None # compute once
+ shift_size_h = window_size_h // 2
+ shift_size_w = window_size_w // 2
+
+ q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
+ k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
+ v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
+
+ q = split_feature(q, num_splits=num_splits, channel_last=True) # [B*K*K, H/K, W/K, C]
+ k = split_feature(k, num_splits=num_splits, channel_last=True)
+ v = split_feature(v, num_splits=num_splits, channel_last=True)
+
+ scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)
+ ) / scale_factor # [B*K*K, H/K*W/K, H/K*W/K]
+
+ if with_shift:
+ scores += attn_mask.repeat(b, 1, 1)
+
+ attn = torch.softmax(scores, dim=-1)
+
+ out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*K*K, H/K*W/K, C]
+
+ out = merge_splits(out.view(b_new, h // num_splits, w // num_splits, c),
+ num_splits=num_splits, channel_last=True) # [B, H, W, C]
+
+ # shift back
+ if with_shift:
+ out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2))
+
+ out = out.view(b, -1, c)
+
+ return out
+
+
+def single_head_split_window_attention_1d(q, k, v,
+ relative_position_bias=None,
+ num_splits=1,
+ with_shift=False,
+ h=None,
+ w=None,
+ attn_mask=None,
+ ):
+ # q, k, v: [B, L, C]
+
+ assert h is not None and w is not None
+ assert q.size(1) == h * w
+
+ b, _, c = q.size()
+
+ b_new = b * num_splits * h
+
+ window_size_w = w // num_splits
+
+ q = q.view(b * h, w, c) # [B*H, W, C]
+ k = k.view(b * h, w, c)
+ v = v.view(b * h, w, c)
+
+ scale_factor = c ** 0.5
+
+ if with_shift:
+ assert attn_mask is not None # compute once
+ shift_size_w = window_size_w // 2
+
+ q = torch.roll(q, shifts=-shift_size_w, dims=1)
+ k = torch.roll(k, shifts=-shift_size_w, dims=1)
+ v = torch.roll(v, shifts=-shift_size_w, dims=1)
+
+ q = split_feature_1d(q, num_splits=num_splits) # [B*H*K, W/K, C]
+ k = split_feature_1d(k, num_splits=num_splits)
+ v = split_feature_1d(v, num_splits=num_splits)
+
+ scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)
+ ) / scale_factor # [B*H*K, W/K, W/K]
+
+ if with_shift:
+ # attn_mask: [K, W/K, W/K]
+ scores += attn_mask.repeat(b * h, 1, 1) # [B*H*K, W/K, W/K]
+
+ attn = torch.softmax(scores, dim=-1)
+
+ out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*H*K, W/K, C]
+
+ out = merge_splits_1d(out, h, num_splits=num_splits) # [B, H, W, C]
+
+ # shift back
+ if with_shift:
+ out = torch.roll(out, shifts=shift_size_w, dims=2)
+
+ out = out.view(b, -1, c)
+
+ return out
+
+
+class SelfAttnPropagation(nn.Module):
+ """
+ flow propagation with self-attention on feature
+ query: feature0, key: feature0, value: flow
+ """
+
+ def __init__(self, in_channels,
+ **kwargs,
+ ):
+ super(SelfAttnPropagation, self).__init__()
+
+ self.q_proj = nn.Linear(in_channels, in_channels)
+ self.k_proj = nn.Linear(in_channels, in_channels)
+
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(self, feature0, flow,
+ local_window_attn=False,
+ local_window_radius=1,
+ **kwargs,
+ ):
+ # q, k: feature [B, C, H, W], v: flow [B, 2, H, W]
+ if local_window_attn:
+ return self.forward_local_window_attn(feature0, flow,
+ local_window_radius=local_window_radius)
+
+ b, c, h, w = feature0.size()
+
+ query = feature0.view(b, c, h * w).permute(0, 2, 1) # [B, H*W, C]
+
+ # a note: the ``correct'' implementation should be:
+ # ``query = self.q_proj(query), key = self.k_proj(query)''
+ # this problem is observed while cleaning up the code
+ # however, this doesn't affect the performance since the projection is a linear operation,
+ # thus the two projection matrices for key can be merged
+ # so I just leave it as is in order to not re-train all models :)
+ query = self.q_proj(query) # [B, H*W, C]
+ key = self.k_proj(query) # [B, H*W, C]
+
+ value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1) # [B, H*W, 2]
+
+ scores = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5) # [B, H*W, H*W]
+ prob = torch.softmax(scores, dim=-1)
+
+ out = torch.matmul(prob, value) # [B, H*W, 2]
+ out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2) # [B, 2, H, W]
+
+ return out
+
+ def forward_local_window_attn(self, feature0, flow,
+ local_window_radius=1,
+ ):
+ assert flow.size(1) == 2 or flow.size(1) == 1 # flow or disparity or depth
+ assert local_window_radius > 0
+
+ b, c, h, w = feature0.size()
+
+ value_channel = flow.size(1)
+
+ feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1)
+ ).reshape(b * h * w, 1, c) # [B*H*W, 1, C]
+
+ kernel_size = 2 * local_window_radius + 1
+
+ feature0_proj = self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1)).permute(0, 2, 1).reshape(b, c, h, w)
+
+ feature0_window = F.unfold(feature0_proj, kernel_size=kernel_size,
+ padding=local_window_radius) # [B, C*(2R+1)^2), H*W]
+
+ feature0_window = feature0_window.view(b, c, kernel_size ** 2, h, w).permute(
+ 0, 3, 4, 1, 2).reshape(b * h * w, c, kernel_size ** 2) # [B*H*W, C, (2R+1)^2]
+
+ flow_window = F.unfold(flow, kernel_size=kernel_size,
+ padding=local_window_radius) # [B, 2*(2R+1)^2), H*W]
+
+ flow_window = flow_window.view(b, value_channel, kernel_size ** 2, h, w).permute(
+ 0, 3, 4, 2, 1).reshape(b * h * w, kernel_size ** 2, value_channel) # [B*H*W, (2R+1)^2, 2]
+
+ scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5) # [B*H*W, 1, (2R+1)^2]
+
+ prob = torch.softmax(scores, dim=-1)
+
+ out = torch.matmul(prob, flow_window).view(b, h, w, value_channel
+ ).permute(0, 3, 1, 2).contiguous() # [B, 2, H, W]
+
+ return out
diff --git a/src/custom_controlnet_aux/unimatch/unimatch/backbone.py b/src/custom_controlnet_aux/unimatch/unimatch/backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..a30942eca9cad56e75252c3026dca95bf1021df7
--- /dev/null
+++ b/src/custom_controlnet_aux/unimatch/unimatch/backbone.py
@@ -0,0 +1,117 @@
+import torch.nn as nn
+
+from .trident_conv import MultiScaleTridentConv
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1,
+ ):
+ super(ResidualBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
+ dilation=dilation, padding=dilation, stride=stride, bias=False)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
+ dilation=dilation, padding=dilation, bias=False)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.norm1 = norm_layer(planes)
+ self.norm2 = norm_layer(planes)
+ if not stride == 1 or in_planes != planes:
+ self.norm3 = norm_layer(planes)
+
+ if stride == 1 and in_planes == planes:
+ self.downsample = None
+ else:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
+
+ def forward(self, x):
+ y = x
+ y = self.relu(self.norm1(self.conv1(y)))
+ y = self.relu(self.norm2(self.conv2(y)))
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x + y)
+
+
+class CNNEncoder(nn.Module):
+ def __init__(self, output_dim=128,
+ norm_layer=nn.InstanceNorm2d,
+ num_output_scales=1,
+ **kwargs,
+ ):
+ super(CNNEncoder, self).__init__()
+ self.num_branch = num_output_scales
+
+ feature_dims = [64, 96, 128]
+
+ self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2
+ self.norm1 = norm_layer(feature_dims[0])
+ self.relu1 = nn.ReLU(inplace=True)
+
+ self.in_planes = feature_dims[0]
+ self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer) # 1/2
+ self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4
+
+ # highest resolution 1/4 or 1/8
+ stride = 2 if num_output_scales == 1 else 1
+ self.layer3 = self._make_layer(feature_dims[2], stride=stride,
+ norm_layer=norm_layer,
+ ) # 1/4 or 1/8
+
+ self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0)
+
+ if self.num_branch > 1:
+ if self.num_branch == 4:
+ strides = (1, 2, 4, 8)
+ elif self.num_branch == 3:
+ strides = (1, 2, 4)
+ elif self.num_branch == 2:
+ strides = (1, 2)
+ else:
+ raise ValueError
+
+ self.trident_conv = MultiScaleTridentConv(output_dim, output_dim,
+ kernel_size=3,
+ strides=strides,
+ paddings=1,
+ num_branch=self.num_branch,
+ )
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d):
+ layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation)
+ layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation)
+
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu1(x)
+
+ x = self.layer1(x) # 1/2
+ x = self.layer2(x) # 1/4
+ x = self.layer3(x) # 1/8 or 1/4
+
+ x = self.conv2(x)
+
+ if self.num_branch > 1:
+ out = self.trident_conv([x] * self.num_branch) # high to low res
+ else:
+ out = [x]
+
+ return out
diff --git a/src/custom_controlnet_aux/unimatch/unimatch/geometry.py b/src/custom_controlnet_aux/unimatch/unimatch/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..775a95783aeee66a44e6290525de94909af648df
--- /dev/null
+++ b/src/custom_controlnet_aux/unimatch/unimatch/geometry.py
@@ -0,0 +1,195 @@
+import torch
+import torch.nn.functional as F
+
+
+def coords_grid(b, h, w, homogeneous=False, device=None):
+ y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W]
+
+ stacks = [x, y]
+
+ if homogeneous:
+ ones = torch.ones_like(x) # [H, W]
+ stacks.append(ones)
+
+ grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W]
+
+ grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W]
+
+ if device is not None:
+ grid = grid.to(device)
+
+ return grid
+
+
+def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
+ assert device is not None
+
+ x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device),
+ torch.linspace(h_min, h_max, len_h, device=device)],
+ )
+ grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2]
+
+ return grid
+
+
+def normalize_coords(coords, h, w):
+ # coords: [B, H, W, 2]
+ c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device)
+ return (coords - c) / c # [-1, 1]
+
+
+def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False):
+ # img: [B, C, H, W]
+ # sample_coords: [B, 2, H, W] in image scale
+ if sample_coords.size(1) != 2: # [B, H, W, 2]
+ sample_coords = sample_coords.permute(0, 3, 1, 2)
+
+ b, _, h, w = sample_coords.shape
+
+ # Normalize to [-1, 1]
+ x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1
+ y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1
+
+ grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2]
+
+ img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True)
+
+ if return_mask:
+ mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W]
+
+ return img, mask
+
+ return img
+
+
+def flow_warp(feature, flow, mask=False, padding_mode='zeros'):
+ b, c, h, w = feature.size()
+ assert flow.size(1) == 2
+
+ grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W]
+
+ return bilinear_sample(feature, grid, padding_mode=padding_mode,
+ return_mask=mask)
+
+
+def forward_backward_consistency_check(fwd_flow, bwd_flow,
+ alpha=0.01,
+ beta=0.5
+ ):
+ # fwd_flow, bwd_flow: [B, 2, H, W]
+ # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837)
+ assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
+ assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2
+ flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W]
+
+ warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W]
+ warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W]
+
+ diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W]
+ diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)
+
+ threshold = alpha * flow_mag + beta
+
+ fwd_occ = (diff_fwd > threshold).float() # [B, H, W]
+ bwd_occ = (diff_bwd > threshold).float()
+
+ return fwd_occ, bwd_occ
+
+
+def back_project(depth, intrinsics):
+ # Back project 2D pixel coords to 3D points
+ # depth: [B, H, W]
+ # intrinsics: [B, 3, 3]
+ b, h, w = depth.shape
+ grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W]
+
+ intrinsics_inv = torch.inverse(intrinsics) # [B, 3, 3]
+
+ points = intrinsics_inv.bmm(grid.view(b, 3, -1)).view(b, 3, h, w) * depth.unsqueeze(1) # [B, 3, H, W]
+
+ return points
+
+
+def camera_transform(points_ref, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None):
+ # Transform 3D points from reference camera to target camera
+ # points_ref: [B, 3, H, W]
+ # extrinsics_ref: [B, 4, 4]
+ # extrinsics_tgt: [B, 4, 4]
+ # extrinsics_rel: [B, 4, 4], relative pose transform
+ b, _, h, w = points_ref.shape
+
+ if extrinsics_rel is None:
+ extrinsics_rel = torch.bmm(extrinsics_tgt, torch.inverse(extrinsics_ref)) # [B, 4, 4]
+
+ points_tgt = torch.bmm(extrinsics_rel[:, :3, :3],
+ points_ref.view(b, 3, -1)) + extrinsics_rel[:, :3, -1:] # [B, 3, H*W]
+
+ points_tgt = points_tgt.view(b, 3, h, w) # [B, 3, H, W]
+
+ return points_tgt
+
+
+def reproject(points_tgt, intrinsics, return_mask=False):
+ # reproject to target view
+ # points_tgt: [B, 3, H, W]
+ # intrinsics: [B, 3, 3]
+
+ b, _, h, w = points_tgt.shape
+
+ proj_points = torch.bmm(intrinsics, points_tgt.view(b, 3, -1)).view(b, 3, h, w) # [B, 3, H, W]
+
+ X = proj_points[:, 0]
+ Y = proj_points[:, 1]
+ Z = proj_points[:, 2].clamp(min=1e-3)
+
+ pixel_coords = torch.stack([X / Z, Y / Z], dim=1).view(b, 2, h, w) # [B, 2, H, W] in image scale
+
+ if return_mask:
+ # valid mask in pixel space
+ mask = (pixel_coords[:, 0] >= 0) & (pixel_coords[:, 0] <= (w - 1)) & (
+ pixel_coords[:, 1] >= 0) & (pixel_coords[:, 1] <= (h - 1)) # [B, H, W]
+
+ return pixel_coords, mask
+
+ return pixel_coords
+
+
+def reproject_coords(depth_ref, intrinsics, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None,
+ return_mask=False):
+ # Compute reprojection sample coords
+ points_ref = back_project(depth_ref, intrinsics) # [B, 3, H, W]
+ points_tgt = camera_transform(points_ref, extrinsics_ref, extrinsics_tgt, extrinsics_rel=extrinsics_rel)
+
+ if return_mask:
+ reproj_coords, mask = reproject(points_tgt, intrinsics,
+ return_mask=return_mask) # [B, 2, H, W] in image scale
+
+ return reproj_coords, mask
+
+ reproj_coords = reproject(points_tgt, intrinsics,
+ return_mask=return_mask) # [B, 2, H, W] in image scale
+
+ return reproj_coords
+
+
+def compute_flow_with_depth_pose(depth_ref, intrinsics,
+ extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None,
+ return_mask=False):
+ b, h, w = depth_ref.shape
+ coords_init = coords_grid(b, h, w, device=depth_ref.device) # [B, 2, H, W]
+
+ if return_mask:
+ reproj_coords, mask = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt,
+ extrinsics_rel=extrinsics_rel,
+ return_mask=return_mask) # [B, 2, H, W]
+ rigid_flow = reproj_coords - coords_init
+
+ return rigid_flow, mask
+
+ reproj_coords = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt,
+ extrinsics_rel=extrinsics_rel,
+ return_mask=return_mask) # [B, 2, H, W]
+
+ rigid_flow = reproj_coords - coords_init
+
+ return rigid_flow
diff --git a/src/custom_controlnet_aux/unimatch/unimatch/matching.py b/src/custom_controlnet_aux/unimatch/unimatch/matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..595437f2307202ab36d7c2ee3dfa0ab44e4dc830
--- /dev/null
+++ b/src/custom_controlnet_aux/unimatch/unimatch/matching.py
@@ -0,0 +1,279 @@
+import torch
+import torch.nn.functional as F
+
+from .geometry import coords_grid, generate_window_grid, normalize_coords
+
+
+def global_correlation_softmax(feature0, feature1,
+ pred_bidir_flow=False,
+ ):
+ # global correlation
+ b, c, h, w = feature0.shape
+ feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C]
+ feature1 = feature1.view(b, c, -1) # [B, C, H*W]
+
+ correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (c ** 0.5) # [B, H, W, H, W]
+
+ # flow from softmax
+ init_grid = coords_grid(b, h, w).to(correlation.device) # [B, 2, H, W]
+ grid = init_grid.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2]
+
+ correlation = correlation.view(b, h * w, h * w) # [B, H*W, H*W]
+
+ if pred_bidir_flow:
+ correlation = torch.cat((correlation, correlation.permute(0, 2, 1)), dim=0) # [2*B, H*W, H*W]
+ init_grid = init_grid.repeat(2, 1, 1, 1) # [2*B, 2, H, W]
+ grid = grid.repeat(2, 1, 1) # [2*B, H*W, 2]
+ b = b * 2
+
+ prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W]
+
+ correspondence = torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W]
+
+ # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow
+ flow = correspondence - init_grid
+
+ return flow, prob
+
+
+def local_correlation_softmax(feature0, feature1, local_radius,
+ padding_mode='zeros',
+ ):
+ b, c, h, w = feature0.size()
+ coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W]
+ coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2]
+
+ local_h = 2 * local_radius + 1
+ local_w = 2 * local_radius + 1
+
+ window_grid = generate_window_grid(-local_radius, local_radius,
+ -local_radius, local_radius,
+ local_h, local_w, device=feature0.device) # [2R+1, 2R+1, 2]
+ window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2]
+ sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1)^2, 2]
+
+ sample_coords_softmax = sample_coords
+
+ # exclude coords that are out of image space
+ valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2]
+ valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2]
+
+ valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax
+
+ # normalize coordinates to [-1, 1]
+ sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1]
+ window_feature = F.grid_sample(feature1, sample_coords_norm,
+ padding_mode=padding_mode, align_corners=True
+ ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)^2]
+ feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C]
+
+ corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)^2]
+
+ # mask invalid locations
+ corr[~valid] = -1e9
+
+ prob = F.softmax(corr, -1) # [B, H*W, (2R+1)^2]
+
+ correspondence = torch.matmul(prob.unsqueeze(-2), sample_coords_softmax).squeeze(-2).view(
+ b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W]
+
+ flow = correspondence - coords_init
+ match_prob = prob
+
+ return flow, match_prob
+
+
+def local_correlation_with_flow(feature0, feature1,
+ flow,
+ local_radius,
+ padding_mode='zeros',
+ dilation=1,
+ ):
+ b, c, h, w = feature0.size()
+ coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W]
+ coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2]
+
+ local_h = 2 * local_radius + 1
+ local_w = 2 * local_radius + 1
+
+ window_grid = generate_window_grid(-local_radius, local_radius,
+ -local_radius, local_radius,
+ local_h, local_w, device=feature0.device) # [2R+1, 2R+1, 2]
+ window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2]
+ sample_coords = coords.unsqueeze(-2) + window_grid * dilation # [B, H*W, (2R+1)^2, 2]
+
+ # flow can be zero when using features after transformer
+ if not isinstance(flow, float):
+ sample_coords = sample_coords + flow.view(
+ b, 2, -1).permute(0, 2, 1).unsqueeze(-2) # [B, H*W, (2R+1)^2, 2]
+ else:
+ assert flow == 0.
+
+ # normalize coordinates to [-1, 1]
+ sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1]
+ window_feature = F.grid_sample(feature1, sample_coords_norm,
+ padding_mode=padding_mode, align_corners=True
+ ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)^2]
+ feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C]
+
+ corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)^2]
+
+ corr = corr.view(b, h, w, -1).permute(0, 3, 1, 2).contiguous() # [B, (2R+1)^2, H, W]
+
+ return corr
+
+
+def global_correlation_softmax_stereo(feature0, feature1,
+ ):
+ # global correlation on horizontal direction
+ b, c, h, w = feature0.shape
+
+ x_grid = torch.linspace(0, w - 1, w, device=feature0.device) # [W]
+
+ feature0 = feature0.permute(0, 2, 3, 1) # [B, H, W, C]
+ feature1 = feature1.permute(0, 2, 1, 3) # [B, H, C, W]
+
+ correlation = torch.matmul(feature0, feature1) / (c ** 0.5) # [B, H, W, W]
+
+ # mask subsequent positions to make disparity positive
+ mask = torch.triu(torch.ones((w, w)), diagonal=1).type_as(feature0) # [W, W]
+ valid_mask = (mask == 0).unsqueeze(0).unsqueeze(0).repeat(b, h, 1, 1) # [B, H, W, W]
+
+ correlation[~valid_mask] = -1e9
+
+ prob = F.softmax(correlation, dim=-1) # [B, H, W, W]
+
+ correspondence = (x_grid.view(1, 1, 1, w) * prob).sum(-1) # [B, H, W]
+
+ # NOTE: unlike flow, disparity is typically positive
+ disparity = x_grid.view(1, 1, w).repeat(b, h, 1) - correspondence # [B, H, W]
+
+ return disparity.unsqueeze(1), prob # feature resolution
+
+
+def local_correlation_softmax_stereo(feature0, feature1, local_radius,
+ ):
+ b, c, h, w = feature0.size()
+ coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W]
+ coords = coords_init.view(b, 2, -1).permute(0, 2, 1).contiguous() # [B, H*W, 2]
+
+ local_h = 1
+ local_w = 2 * local_radius + 1
+
+ window_grid = generate_window_grid(0, 0,
+ -local_radius, local_radius,
+ local_h, local_w, device=feature0.device) # [1, 2R+1, 2]
+ window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1), 2]
+ sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1), 2]
+
+ sample_coords_softmax = sample_coords
+
+ # exclude coords that are out of image space
+ valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2]
+ valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2]
+
+ valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax
+
+ # normalize coordinates to [-1, 1]
+ sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1]
+ window_feature = F.grid_sample(feature1, sample_coords_norm,
+ padding_mode='zeros', align_corners=True
+ ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)]
+ feature0_view = feature0.permute(0, 2, 3, 1).contiguous().view(b, h * w, 1, c) # [B, H*W, 1, C]
+
+ corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)]
+
+ # mask invalid locations
+ corr[~valid] = -1e9
+
+ prob = F.softmax(corr, -1) # [B, H*W, (2R+1)]
+
+ correspondence = torch.matmul(prob.unsqueeze(-2),
+ sample_coords_softmax).squeeze(-2).view(
+ b, h, w, 2).permute(0, 3, 1, 2).contiguous() # [B, 2, H, W]
+
+ flow = correspondence - coords_init # flow at feature resolution
+ match_prob = prob
+
+ flow_x = -flow[:, :1] # [B, 1, H, W]
+
+ return flow_x, match_prob
+
+
+def correlation_softmax_depth(feature0, feature1,
+ intrinsics,
+ pose,
+ depth_candidates,
+ depth_from_argmax=False,
+ pred_bidir_depth=False,
+ ):
+ b, c, h, w = feature0.size()
+ assert depth_candidates.dim() == 4 # [B, D, H, W]
+ scale_factor = c ** 0.5
+
+ if pred_bidir_depth:
+ feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0)
+ intrinsics = intrinsics.repeat(2, 1, 1)
+ pose = torch.cat((pose, torch.inverse(pose)), dim=0)
+ depth_candidates = depth_candidates.repeat(2, 1, 1, 1)
+
+ # depth candidates are actually inverse depth
+ warped_feature1 = warp_with_pose_depth_candidates(feature1, intrinsics, pose,
+ 1. / depth_candidates,
+ ) # [B, C, D, H, W]
+
+ correlation = (feature0.unsqueeze(2) * warped_feature1).sum(1) / scale_factor # [B, D, H, W]
+
+ match_prob = F.softmax(correlation, dim=1) # [B, D, H, W]
+
+ # for cross-task transfer (flow -> depth), extract depth with argmax at test time
+ if depth_from_argmax:
+ index = torch.argmax(match_prob, dim=1, keepdim=True)
+ depth = torch.gather(depth_candidates, dim=1, index=index)
+ else:
+ depth = (match_prob * depth_candidates).sum(dim=1, keepdim=True) # [B, 1, H, W]
+
+ return depth, match_prob
+
+
+def warp_with_pose_depth_candidates(feature1, intrinsics, pose, depth,
+ clamp_min_depth=1e-3,
+ ):
+ """
+ feature1: [B, C, H, W]
+ intrinsics: [B, 3, 3]
+ pose: [B, 4, 4]
+ depth: [B, D, H, W]
+ """
+
+ assert intrinsics.size(1) == intrinsics.size(2) == 3
+ assert pose.size(1) == pose.size(2) == 4
+ assert depth.dim() == 4
+
+ b, d, h, w = depth.size()
+ c = feature1.size(1)
+
+ with torch.no_grad():
+ # pixel coordinates
+ grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W]
+ # back project to 3D and transform viewpoint
+ points = torch.inverse(intrinsics).bmm(grid.view(b, 3, -1)) # [B, 3, H*W]
+ points = torch.bmm(pose[:, :3, :3], points).unsqueeze(2).repeat(
+ 1, 1, d, 1) * depth.view(b, 1, d, h * w) # [B, 3, D, H*W]
+ points = points + pose[:, :3, -1:].unsqueeze(-1) # [B, 3, D, H*W]
+ # reproject to 2D image plane
+ points = torch.bmm(intrinsics, points.view(b, 3, -1)).view(b, 3, d, h * w) # [B, 3, D, H*W]
+ pixel_coords = points[:, :2] / points[:, -1:].clamp(min=clamp_min_depth) # [B, 2, D, H*W]
+
+ # normalize to [-1, 1]
+ x_grid = 2 * pixel_coords[:, 0] / (w - 1) - 1
+ y_grid = 2 * pixel_coords[:, 1] / (h - 1) - 1
+
+ grid = torch.stack([x_grid, y_grid], dim=-1) # [B, D, H*W, 2]
+
+ # sample features
+ warped_feature = F.grid_sample(feature1, grid.view(b, d * h, w, 2), mode='bilinear',
+ padding_mode='zeros',
+ align_corners=True).view(b, c, d, h, w) # [B, C, D, H, W]
+
+ return warped_feature
diff --git a/src/custom_controlnet_aux/unimatch/unimatch/position.py b/src/custom_controlnet_aux/unimatch/unimatch/position.py
new file mode 100644
index 0000000000000000000000000000000000000000..14a6da436c818b7c2784e92dba66f7947d34b7ce
--- /dev/null
+++ b/src/custom_controlnet_aux/unimatch/unimatch/position.py
@@ -0,0 +1,46 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py
+
+import torch
+import torch.nn as nn
+import math
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, x):
+ # x = tensor_list.tensors # [B, C, H, W]
+ # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0
+ b, c, h, w = x.size()
+ mask = torch.ones((b, h, w), device=x.device) # [B, H, W]
+ y_embed = mask.cumsum(1, dtype=torch.float32)
+ x_embed = mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
diff --git a/src/custom_controlnet_aux/unimatch/unimatch/reg_refine.py b/src/custom_controlnet_aux/unimatch/unimatch/reg_refine.py
new file mode 100644
index 0000000000000000000000000000000000000000..47f83da1c5dcd476069e841d045db04998be3604
--- /dev/null
+++ b/src/custom_controlnet_aux/unimatch/unimatch/reg_refine.py
@@ -0,0 +1,119 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class FlowHead(nn.Module):
+ def __init__(self, input_dim=128, hidden_dim=256,
+ out_dim=2,
+ ):
+ super(FlowHead, self).__init__()
+
+ self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
+ self.conv2 = nn.Conv2d(hidden_dim, out_dim, 3, padding=1)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ out = self.conv2(self.relu(self.conv1(x)))
+
+ return out
+
+
+class SepConvGRU(nn.Module):
+ def __init__(self, hidden_dim=128, input_dim=192 + 128,
+ kernel_size=5,
+ ):
+ padding = (kernel_size - 1) // 2
+
+ super(SepConvGRU, self).__init__()
+ self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding))
+ self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding))
+ self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding))
+
+ self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0))
+ self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0))
+ self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0))
+
+ def forward(self, h, x):
+ # horizontal
+ hx = torch.cat([h, x], dim=1)
+ z = torch.sigmoid(self.convz1(hx))
+ r = torch.sigmoid(self.convr1(hx))
+ q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))
+ h = (1 - z) * h + z * q
+
+ # vertical
+ hx = torch.cat([h, x], dim=1)
+ z = torch.sigmoid(self.convz2(hx))
+ r = torch.sigmoid(self.convr2(hx))
+ q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
+ h = (1 - z) * h + z * q
+
+ return h
+
+
+class BasicMotionEncoder(nn.Module):
+ def __init__(self, corr_channels=324,
+ flow_channels=2,
+ ):
+ super(BasicMotionEncoder, self).__init__()
+
+ self.convc1 = nn.Conv2d(corr_channels, 256, 1, padding=0)
+ self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
+ self.convf1 = nn.Conv2d(flow_channels, 128, 7, padding=3)
+ self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
+ self.conv = nn.Conv2d(64 + 192, 128 - flow_channels, 3, padding=1)
+
+ def forward(self, flow, corr):
+ cor = F.relu(self.convc1(corr))
+ cor = F.relu(self.convc2(cor))
+ flo = F.relu(self.convf1(flow))
+ flo = F.relu(self.convf2(flo))
+
+ cor_flo = torch.cat([cor, flo], dim=1)
+ out = F.relu(self.conv(cor_flo))
+ return torch.cat([out, flow], dim=1)
+
+
+class BasicUpdateBlock(nn.Module):
+ def __init__(self, corr_channels=324,
+ hidden_dim=128,
+ context_dim=128,
+ downsample_factor=8,
+ flow_dim=2,
+ bilinear_up=False,
+ ):
+ super(BasicUpdateBlock, self).__init__()
+
+ self.encoder = BasicMotionEncoder(corr_channels=corr_channels,
+ flow_channels=flow_dim,
+ )
+
+ self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=context_dim + hidden_dim)
+
+ self.flow_head = FlowHead(hidden_dim, hidden_dim=256,
+ out_dim=flow_dim,
+ )
+
+ if bilinear_up:
+ self.mask = None
+ else:
+ self.mask = nn.Sequential(
+ nn.Conv2d(hidden_dim, 256, 3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(256, downsample_factor ** 2 * 9, 1, padding=0))
+
+ def forward(self, net, inp, corr, flow):
+ motion_features = self.encoder(flow, corr)
+
+ inp = torch.cat([inp, motion_features], dim=1)
+
+ net = self.gru(net, inp)
+ delta_flow = self.flow_head(net)
+
+ if self.mask is not None:
+ mask = self.mask(net)
+ else:
+ mask = None
+
+ return net, mask, delta_flow
diff --git a/src/custom_controlnet_aux/unimatch/unimatch/transformer.py b/src/custom_controlnet_aux/unimatch/unimatch/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4878e23a64f6609b1bf10740b0a794d8da836c31
--- /dev/null
+++ b/src/custom_controlnet_aux/unimatch/unimatch/transformer.py
@@ -0,0 +1,294 @@
+import torch
+import torch.nn as nn
+
+from .attention import (single_head_full_attention, single_head_split_window_attention,
+ single_head_full_attention_1d, single_head_split_window_attention_1d)
+from .utils import generate_shift_window_attn_mask, generate_shift_window_attn_mask_1d
+
+
+class TransformerLayer(nn.Module):
+ def __init__(self,
+ d_model=128,
+ nhead=1,
+ no_ffn=False,
+ ffn_dim_expansion=4,
+ ):
+ super(TransformerLayer, self).__init__()
+
+ self.dim = d_model
+ self.nhead = nhead
+ self.no_ffn = no_ffn
+
+ # multi-head attention
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
+
+ self.merge = nn.Linear(d_model, d_model, bias=False)
+
+ self.norm1 = nn.LayerNorm(d_model)
+
+ # no ffn after self-attn, with ffn after cross-attn
+ if not self.no_ffn:
+ in_channels = d_model * 2
+ self.mlp = nn.Sequential(
+ nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
+ nn.GELU(),
+ nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
+ )
+
+ self.norm2 = nn.LayerNorm(d_model)
+
+ def forward(self, source, target,
+ height=None,
+ width=None,
+ shifted_window_attn_mask=None,
+ shifted_window_attn_mask_1d=None,
+ attn_type='swin',
+ with_shift=False,
+ attn_num_splits=None,
+ ):
+ # source, target: [B, L, C]
+ query, key, value = source, target, target
+
+ # for stereo: 2d attn in self-attn, 1d attn in cross-attn
+ is_self_attn = (query - key).abs().max() < 1e-6
+
+ # single-head attention
+ query = self.q_proj(query) # [B, L, C]
+ key = self.k_proj(key) # [B, L, C]
+ value = self.v_proj(value) # [B, L, C]
+
+ if attn_type == 'swin' and attn_num_splits > 1: # self, cross-attn: both swin 2d
+ if self.nhead > 1:
+ # we observe that multihead attention slows down the speed and increases the memory consumption
+ # without bringing obvious performance gains and thus the implementation is removed
+ raise NotImplementedError
+ else:
+ message = single_head_split_window_attention(query, key, value,
+ num_splits=attn_num_splits,
+ with_shift=with_shift,
+ h=height,
+ w=width,
+ attn_mask=shifted_window_attn_mask,
+ )
+
+ elif attn_type == 'self_swin2d_cross_1d': # self-attn: swin 2d, cross-attn: full 1d
+ if self.nhead > 1:
+ raise NotImplementedError
+ else:
+ if is_self_attn:
+ if attn_num_splits > 1:
+ message = single_head_split_window_attention(query, key, value,
+ num_splits=attn_num_splits,
+ with_shift=with_shift,
+ h=height,
+ w=width,
+ attn_mask=shifted_window_attn_mask,
+ )
+ else:
+ # full 2d attn
+ message = single_head_full_attention(query, key, value) # [N, L, C]
+
+ else:
+ # cross attn 1d
+ message = single_head_full_attention_1d(query, key, value,
+ h=height,
+ w=width,
+ )
+
+ elif attn_type == 'self_swin2d_cross_swin1d': # self-attn: swin 2d, cross-attn: swin 1d
+ if self.nhead > 1:
+ raise NotImplementedError
+ else:
+ if is_self_attn:
+ if attn_num_splits > 1:
+ # self attn shift window
+ message = single_head_split_window_attention(query, key, value,
+ num_splits=attn_num_splits,
+ with_shift=with_shift,
+ h=height,
+ w=width,
+ attn_mask=shifted_window_attn_mask,
+ )
+ else:
+ # full 2d attn
+ message = single_head_full_attention(query, key, value) # [N, L, C]
+ else:
+ if attn_num_splits > 1:
+ assert shifted_window_attn_mask_1d is not None
+ # cross attn 1d shift
+ message = single_head_split_window_attention_1d(query, key, value,
+ num_splits=attn_num_splits,
+ with_shift=with_shift,
+ h=height,
+ w=width,
+ attn_mask=shifted_window_attn_mask_1d,
+ )
+ else:
+ message = single_head_full_attention_1d(query, key, value,
+ h=height,
+ w=width,
+ )
+
+ else:
+ message = single_head_full_attention(query, key, value) # [B, L, C]
+
+ message = self.merge(message) # [B, L, C]
+ message = self.norm1(message)
+
+ if not self.no_ffn:
+ message = self.mlp(torch.cat([source, message], dim=-1))
+ message = self.norm2(message)
+
+ return source + message
+
+
+class TransformerBlock(nn.Module):
+ """self attention + cross attention + FFN"""
+
+ def __init__(self,
+ d_model=128,
+ nhead=1,
+ ffn_dim_expansion=4,
+ ):
+ super(TransformerBlock, self).__init__()
+
+ self.self_attn = TransformerLayer(d_model=d_model,
+ nhead=nhead,
+ no_ffn=True,
+ ffn_dim_expansion=ffn_dim_expansion,
+ )
+
+ self.cross_attn_ffn = TransformerLayer(d_model=d_model,
+ nhead=nhead,
+ ffn_dim_expansion=ffn_dim_expansion,
+ )
+
+ def forward(self, source, target,
+ height=None,
+ width=None,
+ shifted_window_attn_mask=None,
+ shifted_window_attn_mask_1d=None,
+ attn_type='swin',
+ with_shift=False,
+ attn_num_splits=None,
+ ):
+ # source, target: [B, L, C]
+
+ # self attention
+ source = self.self_attn(source, source,
+ height=height,
+ width=width,
+ shifted_window_attn_mask=shifted_window_attn_mask,
+ attn_type=attn_type,
+ with_shift=with_shift,
+ attn_num_splits=attn_num_splits,
+ )
+
+ # cross attention and ffn
+ source = self.cross_attn_ffn(source, target,
+ height=height,
+ width=width,
+ shifted_window_attn_mask=shifted_window_attn_mask,
+ shifted_window_attn_mask_1d=shifted_window_attn_mask_1d,
+ attn_type=attn_type,
+ with_shift=with_shift,
+ attn_num_splits=attn_num_splits,
+ )
+
+ return source
+
+
+class FeatureTransformer(nn.Module):
+ def __init__(self,
+ num_layers=6,
+ d_model=128,
+ nhead=1,
+ ffn_dim_expansion=4,
+ ):
+ super(FeatureTransformer, self).__init__()
+
+ self.d_model = d_model
+ self.nhead = nhead
+
+ self.layers = nn.ModuleList([
+ TransformerBlock(d_model=d_model,
+ nhead=nhead,
+ ffn_dim_expansion=ffn_dim_expansion,
+ )
+ for i in range(num_layers)])
+
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(self, feature0, feature1,
+ attn_type='swin',
+ attn_num_splits=None,
+ **kwargs,
+ ):
+
+ b, c, h, w = feature0.shape
+ assert self.d_model == c
+
+ feature0 = feature0.flatten(-2).permute(0, 2, 1) # [B, H*W, C]
+ feature1 = feature1.flatten(-2).permute(0, 2, 1) # [B, H*W, C]
+
+ # 2d attention
+ if 'swin' in attn_type and attn_num_splits > 1:
+ # global and refine use different number of splits
+ window_size_h = h // attn_num_splits
+ window_size_w = w // attn_num_splits
+
+ # compute attn mask once
+ shifted_window_attn_mask = generate_shift_window_attn_mask(
+ input_resolution=(h, w),
+ window_size_h=window_size_h,
+ window_size_w=window_size_w,
+ shift_size_h=window_size_h // 2,
+ shift_size_w=window_size_w // 2,
+ device=feature0.device,
+ ) # [K*K, H/K*W/K, H/K*W/K]
+ else:
+ shifted_window_attn_mask = None
+
+ # 1d attention
+ if 'swin1d' in attn_type and attn_num_splits > 1:
+ window_size_w = w // attn_num_splits
+
+ # compute attn mask once
+ shifted_window_attn_mask_1d = generate_shift_window_attn_mask_1d(
+ input_w=w,
+ window_size_w=window_size_w,
+ shift_size_w=window_size_w // 2,
+ device=feature0.device,
+ ) # [K, W/K, W/K]
+ else:
+ shifted_window_attn_mask_1d = None
+
+ # concat feature0 and feature1 in batch dimension to compute in parallel
+ concat0 = torch.cat((feature0, feature1), dim=0) # [2B, H*W, C]
+ concat1 = torch.cat((feature1, feature0), dim=0) # [2B, H*W, C]
+
+ for i, layer in enumerate(self.layers):
+ concat0 = layer(concat0, concat1,
+ height=h,
+ width=w,
+ attn_type=attn_type,
+ with_shift='swin' in attn_type and attn_num_splits > 1 and i % 2 == 1,
+ attn_num_splits=attn_num_splits,
+ shifted_window_attn_mask=shifted_window_attn_mask,
+ shifted_window_attn_mask_1d=shifted_window_attn_mask_1d,
+ )
+
+ # update feature1
+ concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0)
+
+ feature0, feature1 = concat0.chunk(chunks=2, dim=0) # [B, H*W, C]
+
+ # reshape back
+ feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W]
+ feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W]
+
+ return feature0, feature1
diff --git a/src/custom_controlnet_aux/unimatch/unimatch/trident_conv.py b/src/custom_controlnet_aux/unimatch/unimatch/trident_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..29a2a73e964a88b68bc095772d9c3cc443e3e0fe
--- /dev/null
+++ b/src/custom_controlnet_aux/unimatch/unimatch/trident_conv.py
@@ -0,0 +1,90 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.modules.utils import _pair
+
+
+class MultiScaleTridentConv(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ strides=1,
+ paddings=0,
+ dilations=1,
+ dilation=1,
+ groups=1,
+ num_branch=1,
+ test_branch_idx=-1,
+ bias=False,
+ norm=None,
+ activation=None,
+ ):
+ super(MultiScaleTridentConv, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.num_branch = num_branch
+ self.stride = _pair(stride)
+ self.groups = groups
+ self.with_bias = bias
+ self.dilation = dilation
+ if isinstance(paddings, int):
+ paddings = [paddings] * self.num_branch
+ if isinstance(dilations, int):
+ dilations = [dilations] * self.num_branch
+ if isinstance(strides, int):
+ strides = [strides] * self.num_branch
+ self.paddings = [_pair(padding) for padding in paddings]
+ self.dilations = [_pair(dilation) for dilation in dilations]
+ self.strides = [_pair(stride) for stride in strides]
+ self.test_branch_idx = test_branch_idx
+ self.norm = norm
+ self.activation = activation
+
+ assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1
+
+ self.weight = nn.Parameter(
+ torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
+ )
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
+ else:
+ self.bias = None
+
+ nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
+ if self.bias is not None:
+ nn.init.constant_(self.bias, 0)
+
+ def forward(self, inputs):
+ num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1
+ assert len(inputs) == num_branch
+
+ if self.training or self.test_branch_idx == -1:
+ outputs = [
+ F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups)
+ for input, stride, padding in zip(inputs, self.strides, self.paddings)
+ ]
+ else:
+ outputs = [
+ F.conv2d(
+ inputs[0],
+ self.weight,
+ self.bias,
+ self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1],
+ self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1],
+ self.dilation,
+ self.groups,
+ )
+ ]
+
+ if self.norm is not None:
+ outputs = [self.norm(x) for x in outputs]
+ if self.activation is not None:
+ outputs = [self.activation(x) for x in outputs]
+ return outputs
diff --git a/src/custom_controlnet_aux/unimatch/unimatch/unimatch.py b/src/custom_controlnet_aux/unimatch/unimatch/unimatch.py
new file mode 100644
index 0000000000000000000000000000000000000000..c442c67433ac2c42ebe6b4dca6ab7ff4d765e8cb
--- /dev/null
+++ b/src/custom_controlnet_aux/unimatch/unimatch/unimatch.py
@@ -0,0 +1,367 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .backbone import CNNEncoder
+from .transformer import FeatureTransformer
+from .matching import (global_correlation_softmax, local_correlation_softmax, local_correlation_with_flow,
+ global_correlation_softmax_stereo, local_correlation_softmax_stereo,
+ correlation_softmax_depth)
+from .attention import SelfAttnPropagation
+from .geometry import flow_warp, compute_flow_with_depth_pose
+from .reg_refine import BasicUpdateBlock
+from .utils import normalize_img, feature_add_position, upsample_flow_with_mask
+
+
+class UniMatch(nn.Module):
+ def __init__(self,
+ num_scales=1,
+ feature_channels=128,
+ upsample_factor=8,
+ num_head=1,
+ ffn_dim_expansion=4,
+ num_transformer_layers=6,
+ reg_refine=False, # optional local regression refinement
+ task='flow',
+ ):
+ super(UniMatch, self).__init__()
+
+ self.feature_channels = feature_channels
+ self.num_scales = num_scales
+ self.upsample_factor = upsample_factor
+ self.reg_refine = reg_refine
+
+ # CNN
+ self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales)
+
+ # Transformer
+ self.transformer = FeatureTransformer(num_layers=num_transformer_layers,
+ d_model=feature_channels,
+ nhead=num_head,
+ ffn_dim_expansion=ffn_dim_expansion,
+ )
+
+ # propagation with self-attn
+ self.feature_flow_attn = SelfAttnPropagation(in_channels=feature_channels)
+
+ if not self.reg_refine or task == 'depth':
+ # convex upsampling simiar to RAFT
+ # concat feature0 and low res flow as input
+ self.upsampler = nn.Sequential(nn.Conv2d(2 + feature_channels, 256, 3, 1, 1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(256, upsample_factor ** 2 * 9, 1, 1, 0))
+ # thus far, all the learnable parameters are task-agnostic
+
+ if reg_refine:
+ # optional task-specific local regression refinement
+ self.refine_proj = nn.Conv2d(128, 256, 1)
+ self.refine = BasicUpdateBlock(corr_channels=(2 * 4 + 1) ** 2,
+ downsample_factor=upsample_factor,
+ flow_dim=2 if task == 'flow' else 1,
+ bilinear_up=task == 'depth',
+ )
+
+ def extract_feature(self, img0, img1):
+ concat = torch.cat((img0, img1), dim=0) # [2B, C, H, W]
+ features = self.backbone(concat) # list of [2B, C, H, W], resolution from high to low
+
+ # reverse: resolution from low to high
+ features = features[::-1]
+
+ feature0, feature1 = [], []
+
+ for i in range(len(features)):
+ feature = features[i]
+ chunks = torch.chunk(feature, 2, 0) # tuple
+ feature0.append(chunks[0])
+ feature1.append(chunks[1])
+
+ return feature0, feature1
+
+ def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8,
+ is_depth=False):
+ if bilinear:
+ multiplier = 1 if is_depth else upsample_factor
+ up_flow = F.interpolate(flow, scale_factor=upsample_factor,
+ mode='bilinear', align_corners=True) * multiplier
+ else:
+ concat = torch.cat((flow, feature), dim=1)
+ mask = self.upsampler(concat)
+ up_flow = upsample_flow_with_mask(flow, mask, upsample_factor=self.upsample_factor,
+ is_depth=is_depth)
+
+ return up_flow
+
+ def forward(self, img0, img1,
+ attn_type=None,
+ attn_splits_list=None,
+ corr_radius_list=None,
+ prop_radius_list=None,
+ num_reg_refine=1,
+ pred_bidir_flow=False,
+ task='flow',
+ intrinsics=None,
+ pose=None, # relative pose transform
+ min_depth=1. / 0.5, # inverse depth range
+ max_depth=1. / 10,
+ num_depth_candidates=64,
+ depth_from_argmax=False,
+ pred_bidir_depth=False,
+ **kwargs,
+ ):
+
+ if pred_bidir_flow:
+ assert task == 'flow'
+
+ if task == 'depth':
+ assert self.num_scales == 1 # multi-scale depth model is not supported yet
+
+ results_dict = {}
+ flow_preds = []
+
+ if task == 'flow':
+ # stereo and depth tasks have normalized img in dataloader
+ img0, img1 = normalize_img(img0, img1) # [B, 3, H, W]
+
+ # list of features, resolution low to high
+ feature0_list, feature1_list = self.extract_feature(img0, img1) # list of features
+
+ flow = None
+
+ if task != 'depth':
+ assert len(attn_splits_list) == len(corr_radius_list) == len(prop_radius_list) == self.num_scales
+ else:
+ assert len(attn_splits_list) == len(prop_radius_list) == self.num_scales == 1
+
+ for scale_idx in range(self.num_scales):
+ feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]
+
+ if pred_bidir_flow and scale_idx > 0:
+ # predicting bidirectional flow with refinement
+ feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0)
+
+ feature0_ori, feature1_ori = feature0, feature1
+
+ upsample_factor = self.upsample_factor * (2 ** (self.num_scales - 1 - scale_idx))
+
+ if task == 'depth':
+ # scale intrinsics
+ intrinsics_curr = intrinsics.clone()
+ intrinsics_curr[:, :2] = intrinsics_curr[:, :2] / upsample_factor
+
+ if scale_idx > 0:
+ assert task != 'depth' # not supported for multi-scale depth model
+ flow = F.interpolate(flow, scale_factor=2, mode='bilinear', align_corners=True) * 2
+
+ if flow is not None:
+ assert task != 'depth'
+ flow = flow.detach()
+
+ if task == 'stereo':
+ # construct flow vector for disparity
+ # flow here is actually disparity
+ zeros = torch.zeros_like(flow) # [B, 1, H, W]
+ # NOTE: reverse disp, disparity is positive
+ displace = torch.cat((-flow, zeros), dim=1) # [B, 2, H, W]
+ feature1 = flow_warp(feature1, displace) # [B, C, H, W]
+ elif task == 'flow':
+ feature1 = flow_warp(feature1, flow) # [B, C, H, W]
+ else:
+ raise NotImplementedError
+
+ attn_splits = attn_splits_list[scale_idx]
+ if task != 'depth':
+ corr_radius = corr_radius_list[scale_idx]
+ prop_radius = prop_radius_list[scale_idx]
+
+ # add position to features
+ feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels)
+
+ # Transformer
+ feature0, feature1 = self.transformer(feature0, feature1,
+ attn_type=attn_type,
+ attn_num_splits=attn_splits,
+ )
+
+ # correlation and softmax
+ if task == 'depth':
+ # first generate depth candidates
+ b, _, h, w = feature0.size()
+ depth_candidates = torch.linspace(min_depth, max_depth, num_depth_candidates).type_as(feature0)
+ depth_candidates = depth_candidates.view(1, num_depth_candidates, 1, 1).repeat(b, 1, h,
+ w) # [B, D, H, W]
+
+ flow_pred = correlation_softmax_depth(feature0, feature1,
+ intrinsics_curr,
+ pose,
+ depth_candidates=depth_candidates,
+ depth_from_argmax=depth_from_argmax,
+ pred_bidir_depth=pred_bidir_depth,
+ )[0]
+
+ else:
+ if corr_radius == -1: # global matching
+ if task == 'flow':
+ flow_pred = global_correlation_softmax(feature0, feature1, pred_bidir_flow)[0]
+ elif task == 'stereo':
+ flow_pred = global_correlation_softmax_stereo(feature0, feature1)[0]
+ else:
+ raise NotImplementedError
+ else: # local matching
+ if task == 'flow':
+ flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[0]
+ elif task == 'stereo':
+ flow_pred = local_correlation_softmax_stereo(feature0, feature1, corr_radius)[0]
+ else:
+ raise NotImplementedError
+
+ # flow or residual flow
+ flow = flow + flow_pred if flow is not None else flow_pred
+
+ if task == 'stereo':
+ flow = flow.clamp(min=0) # positive disparity
+
+ # upsample to the original resolution for supervison at training time only
+ if self.training:
+ flow_bilinear = self.upsample_flow(flow, None, bilinear=True, upsample_factor=upsample_factor,
+ is_depth=task == 'depth')
+ flow_preds.append(flow_bilinear)
+
+ # flow propagation with self-attn
+ if (pred_bidir_flow or pred_bidir_depth) and scale_idx == 0:
+ feature0 = torch.cat((feature0, feature1), dim=0) # [2*B, C, H, W] for propagation
+
+ flow = self.feature_flow_attn(feature0, flow.detach(),
+ local_window_attn=prop_radius > 0,
+ local_window_radius=prop_radius,
+ )
+
+ # bilinear exclude the last one
+ if self.training and scale_idx < self.num_scales - 1:
+ flow_up = self.upsample_flow(flow, feature0, bilinear=True,
+ upsample_factor=upsample_factor,
+ is_depth=task == 'depth')
+ flow_preds.append(flow_up)
+
+ if scale_idx == self.num_scales - 1:
+ if not self.reg_refine:
+ # upsample to the original image resolution
+
+ if task == 'stereo':
+ flow_pad = torch.cat((-flow, torch.zeros_like(flow)), dim=1) # [B, 2, H, W]
+ flow_up_pad = self.upsample_flow(flow_pad, feature0)
+ flow_up = -flow_up_pad[:, :1] # [B, 1, H, W]
+ elif task == 'depth':
+ depth_pad = torch.cat((flow, torch.zeros_like(flow)), dim=1) # [B, 2, H, W]
+ depth_up_pad = self.upsample_flow(depth_pad, feature0,
+ is_depth=True).clamp(min=min_depth, max=max_depth)
+ flow_up = depth_up_pad[:, :1] # [B, 1, H, W]
+ else:
+ flow_up = self.upsample_flow(flow, feature0)
+
+ flow_preds.append(flow_up)
+ else:
+ # task-specific local regression refinement
+ # supervise current flow
+ if self.training:
+ flow_up = self.upsample_flow(flow, feature0, bilinear=True,
+ upsample_factor=upsample_factor,
+ is_depth=task == 'depth')
+ flow_preds.append(flow_up)
+
+ assert num_reg_refine > 0
+ for refine_iter_idx in range(num_reg_refine):
+ flow = flow.detach()
+
+ if task == 'stereo':
+ zeros = torch.zeros_like(flow) # [B, 1, H, W]
+ # NOTE: reverse disp, disparity is positive
+ displace = torch.cat((-flow, zeros), dim=1) # [B, 2, H, W]
+ correlation = local_correlation_with_flow(
+ feature0_ori,
+ feature1_ori,
+ flow=displace,
+ local_radius=4,
+ ) # [B, (2R+1)^2, H, W]
+ elif task == 'depth':
+ if pred_bidir_depth and refine_iter_idx == 0:
+ intrinsics_curr = intrinsics_curr.repeat(2, 1, 1)
+ pose = torch.cat((pose, torch.inverse(pose)), dim=0)
+
+ feature0_ori, feature1_ori = torch.cat((feature0_ori, feature1_ori),
+ dim=0), torch.cat((feature1_ori,
+ feature0_ori), dim=0)
+
+ flow_from_depth = compute_flow_with_depth_pose(1. / flow.squeeze(1),
+ intrinsics_curr,
+ extrinsics_rel=pose,
+ )
+
+ correlation = local_correlation_with_flow(
+ feature0_ori,
+ feature1_ori,
+ flow=flow_from_depth,
+ local_radius=4,
+ ) # [B, (2R+1)^2, H, W]
+
+ else:
+ correlation = local_correlation_with_flow(
+ feature0_ori,
+ feature1_ori,
+ flow=flow,
+ local_radius=4,
+ ) # [B, (2R+1)^2, H, W]
+
+ proj = self.refine_proj(feature0)
+
+ net, inp = torch.chunk(proj, chunks=2, dim=1)
+
+ net = torch.tanh(net)
+ inp = torch.relu(inp)
+
+ net, up_mask, residual_flow = self.refine(net, inp, correlation, flow.clone(),
+ )
+
+ if task == 'depth':
+ flow = (flow - residual_flow).clamp(min=min_depth, max=max_depth)
+ else:
+ flow = flow + residual_flow
+
+ if task == 'stereo':
+ flow = flow.clamp(min=0) # positive
+
+ if self.training or refine_iter_idx == num_reg_refine - 1:
+ if task == 'depth':
+ if refine_iter_idx < num_reg_refine - 1:
+ # bilinear upsampling
+ flow_up = self.upsample_flow(flow, feature0, bilinear=True,
+ upsample_factor=upsample_factor,
+ is_depth=True)
+ else:
+ # last one convex upsampling
+ # NOTE: clamp depth due to the zero padding in the unfold in the convex upsampling
+ # pad depth to 2 channels as flow
+ depth_pad = torch.cat((flow, torch.zeros_like(flow)), dim=1) # [B, 2, H, W]
+ depth_up_pad = self.upsample_flow(depth_pad, feature0,
+ is_depth=True).clamp(min=min_depth,
+ max=max_depth)
+ flow_up = depth_up_pad[:, :1] # [B, 1, H, W]
+
+ else:
+ flow_up = upsample_flow_with_mask(flow, up_mask, upsample_factor=self.upsample_factor,
+ is_depth=task == 'depth')
+
+ flow_preds.append(flow_up)
+
+ if task == 'stereo':
+ for i in range(len(flow_preds)):
+ flow_preds[i] = flow_preds[i].squeeze(1) # [B, H, W]
+
+ # convert inverse depth to depth
+ if task == 'depth':
+ for i in range(len(flow_preds)):
+ flow_preds[i] = 1. / flow_preds[i].squeeze(1) # [B, H, W]
+
+ results_dict.update({'flow_preds': flow_preds})
+
+ return results_dict
diff --git a/src/custom_controlnet_aux/unimatch/unimatch/utils.py b/src/custom_controlnet_aux/unimatch/unimatch/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a51d3882d082820cf4749ef6e4771b30ca3763b0
--- /dev/null
+++ b/src/custom_controlnet_aux/unimatch/unimatch/utils.py
@@ -0,0 +1,216 @@
+import torch
+import torch.nn.functional as F
+from .position import PositionEmbeddingSine
+
+
+def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
+ assert device is not None
+
+ x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device),
+ torch.linspace(h_min, h_max, len_h, device=device)],
+ )
+ grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2]
+
+ return grid
+
+
+def normalize_coords(coords, h, w):
+ # coords: [B, H, W, 2]
+ c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device)
+ return (coords - c) / c # [-1, 1]
+
+
+def normalize_img(img0, img1):
+ # loaded images are in [0, 255]
+ # normalize by ImageNet mean and std
+ mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device)
+ std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(img1.device)
+ img0 = (img0 / 255. - mean) / std
+ img1 = (img1 / 255. - mean) / std
+
+ return img0, img1
+
+
+def split_feature(feature,
+ num_splits=2,
+ channel_last=False,
+ ):
+ if channel_last: # [B, H, W, C]
+ b, h, w, c = feature.size()
+ assert h % num_splits == 0 and w % num_splits == 0
+
+ b_new = b * num_splits * num_splits
+ h_new = h // num_splits
+ w_new = w // num_splits
+
+ feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c
+ ).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c) # [B*K*K, H/K, W/K, C]
+ else: # [B, C, H, W]
+ b, c, h, w = feature.size()
+ assert h % num_splits == 0 and w % num_splits == 0
+
+ b_new = b * num_splits * num_splits
+ h_new = h // num_splits
+ w_new = w // num_splits
+
+ feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits
+ ).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new) # [B*K*K, C, H/K, W/K]
+
+ return feature
+
+
+def merge_splits(splits,
+ num_splits=2,
+ channel_last=False,
+ ):
+ if channel_last: # [B*K*K, H/K, W/K, C]
+ b, h, w, c = splits.size()
+ new_b = b // num_splits // num_splits
+
+ splits = splits.view(new_b, num_splits, num_splits, h, w, c)
+ merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view(
+ new_b, num_splits * h, num_splits * w, c) # [B, H, W, C]
+ else: # [B*K*K, C, H/K, W/K]
+ b, c, h, w = splits.size()
+ new_b = b // num_splits // num_splits
+
+ splits = splits.view(new_b, num_splits, num_splits, c, h, w)
+ merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view(
+ new_b, c, num_splits * h, num_splits * w) # [B, C, H, W]
+
+ return merge
+
+
+def generate_shift_window_attn_mask(input_resolution, window_size_h, window_size_w,
+ shift_size_h, shift_size_w, device=torch.device('cuda')):
+ # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
+ # calculate attention mask for SW-MSA
+ h, w = input_resolution
+ img_mask = torch.zeros((1, h, w, 1)).to(device) # 1 H W 1
+ h_slices = (slice(0, -window_size_h),
+ slice(-window_size_h, -shift_size_h),
+ slice(-shift_size_h, None))
+ w_slices = (slice(0, -window_size_w),
+ slice(-window_size_w, -shift_size_w),
+ slice(-shift_size_w, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = split_feature(img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True)
+
+ mask_windows = mask_windows.view(-1, window_size_h * window_size_w)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ return attn_mask
+
+
+def feature_add_position(feature0, feature1, attn_splits, feature_channels):
+ pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)
+
+ if attn_splits > 1: # add position in splited window
+ feature0_splits = split_feature(feature0, num_splits=attn_splits)
+ feature1_splits = split_feature(feature1, num_splits=attn_splits)
+
+ position = pos_enc(feature0_splits)
+
+ feature0_splits = feature0_splits + position
+ feature1_splits = feature1_splits + position
+
+ feature0 = merge_splits(feature0_splits, num_splits=attn_splits)
+ feature1 = merge_splits(feature1_splits, num_splits=attn_splits)
+ else:
+ position = pos_enc(feature0)
+
+ feature0 = feature0 + position
+ feature1 = feature1 + position
+
+ return feature0, feature1
+
+
+def upsample_flow_with_mask(flow, up_mask, upsample_factor,
+ is_depth=False):
+ # convex upsampling following raft
+
+ mask = up_mask
+ b, flow_channel, h, w = flow.shape
+ mask = mask.view(b, 1, 9, upsample_factor, upsample_factor, h, w) # [B, 1, 9, K, K, H, W]
+ mask = torch.softmax(mask, dim=2)
+
+ multiplier = 1 if is_depth else upsample_factor
+ up_flow = F.unfold(multiplier * flow, [3, 3], padding=1)
+ up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w) # [B, 2, 9, 1, 1, H, W]
+
+ up_flow = torch.sum(mask * up_flow, dim=2) # [B, 2, K, K, H, W]
+ up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) # [B, 2, K, H, K, W]
+ up_flow = up_flow.reshape(b, flow_channel, upsample_factor * h,
+ upsample_factor * w) # [B, 2, K*H, K*W]
+
+ return up_flow
+
+
+def split_feature_1d(feature,
+ num_splits=2,
+ ):
+ # feature: [B, W, C]
+ b, w, c = feature.size()
+ assert w % num_splits == 0
+
+ b_new = b * num_splits
+ w_new = w // num_splits
+
+ feature = feature.view(b, num_splits, w // num_splits, c
+ ).view(b_new, w_new, c) # [B*K, W/K, C]
+
+ return feature
+
+
+def merge_splits_1d(splits,
+ h,
+ num_splits=2,
+ ):
+ b, w, c = splits.size()
+ new_b = b // num_splits // h
+
+ splits = splits.view(new_b, h, num_splits, w, c)
+ merge = splits.view(
+ new_b, h, num_splits * w, c) # [B, H, W, C]
+
+ return merge
+
+
+def window_partition_1d(x, window_size_w):
+ """
+ Args:
+ x: (B, W, C)
+ window_size (int): window size
+
+ Returns:
+ windows: (num_windows*B, window_size, C)
+ """
+ B, W, C = x.shape
+ x = x.view(B, W // window_size_w, window_size_w, C).view(-1, window_size_w, C)
+ return x
+
+
+def generate_shift_window_attn_mask_1d(input_w, window_size_w,
+ shift_size_w, device=torch.device('cuda')):
+ # calculate attention mask for SW-MSA
+ img_mask = torch.zeros((1, input_w, 1)).to(device) # 1 W 1
+ w_slices = (slice(0, -window_size_w),
+ slice(-window_size_w, -shift_size_w),
+ slice(-shift_size_w, None))
+ cnt = 0
+ for w in w_slices:
+ img_mask[:, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition_1d(img_mask, window_size_w) # nW, window_size, 1
+ mask_windows = mask_windows.view(-1, window_size_w)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # nW, window_size, window_size
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ return attn_mask
diff --git a/src/custom_controlnet_aux/unimatch/utils/dist_utils.py b/src/custom_controlnet_aux/unimatch/utils/dist_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9958a48fe7c0bb5b33457f132b49d513d8420419
--- /dev/null
+++ b/src/custom_controlnet_aux/unimatch/utils/dist_utils.py
@@ -0,0 +1,105 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# https://github.com/open-mmlab/mmcv/blob/7540cf73ac7e5d1e14d0ffbd9b6759e83929ecfc/mmcv/runner/dist_utils.py
+
+import os
+import subprocess
+
+import torch
+import torch.multiprocessing as mp
+from torch import distributed as dist
+
+
+def init_dist(launcher, backend='nccl', **kwargs):
+ if mp.get_start_method(allow_none=True) is None:
+ mp.set_start_method('spawn')
+ if launcher == 'pytorch':
+ _init_dist_pytorch(backend, **kwargs)
+ elif launcher == 'mpi':
+ _init_dist_mpi(backend, **kwargs)
+ elif launcher == 'slurm':
+ _init_dist_slurm(backend, **kwargs)
+ else:
+ raise ValueError(f'Invalid launcher type: {launcher}')
+
+
+def _init_dist_pytorch(backend, **kwargs):
+ # TODO: use local_rank instead of rank % num_gpus
+ rank = int(os.environ['RANK'])
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(rank % num_gpus)
+ dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_mpi(backend, **kwargs):
+ # TODO: use local_rank instead of rank % num_gpus
+ rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(rank % num_gpus)
+ dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_slurm(backend, port=None):
+ """Initialize slurm distributed training environment.
+ If argument ``port`` is not specified, then the master port will be system
+ environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
+ environment variable, then a default port ``29500`` will be used.
+ Args:
+ backend (str): Backend of torch.distributed.
+ port (int, optional): Master port. Defaults to None.
+ """
+ proc_id = int(os.environ['SLURM_PROCID'])
+ ntasks = int(os.environ['SLURM_NTASKS'])
+ node_list = os.environ['SLURM_NODELIST']
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(proc_id % num_gpus)
+ addr = subprocess.getoutput(
+ f'scontrol show hostname {node_list} | head -n1')
+ # specify master port
+ if port is not None:
+ os.environ['MASTER_PORT'] = str(port)
+ elif 'MASTER_PORT' in os.environ:
+ pass # use MASTER_PORT in the environment variable
+ else:
+ # 29500 is torch.distributed default port
+ os.environ['MASTER_PORT'] = '29500'
+ # use MASTER_ADDR in the environment variable if it already exists
+ if 'MASTER_ADDR' not in os.environ:
+ os.environ['MASTER_ADDR'] = addr
+ os.environ['WORLD_SIZE'] = str(ntasks)
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
+ os.environ['RANK'] = str(proc_id)
+ dist.init_process_group(backend=backend)
+
+
+def get_dist_info():
+ # if (TORCH_VERSION != 'parrots'
+ # and digit_version(TORCH_VERSION) < digit_version('1.0')):
+ # initialized = dist._initialized
+ # else:
+ if dist.is_available():
+ initialized = dist.is_initialized()
+ else:
+ initialized = False
+ if initialized:
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ else:
+ rank = 0
+ world_size = 1
+ return rank, world_size
+
+
+# from DETR repo
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ import builtins as __builtin__
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ if is_master or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
diff --git a/src/custom_controlnet_aux/unimatch/utils/file_io.py b/src/custom_controlnet_aux/unimatch/utils/file_io.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f0a99098a2d82dc73987c269795718053880434
--- /dev/null
+++ b/src/custom_controlnet_aux/unimatch/utils/file_io.py
@@ -0,0 +1,224 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import re
+from PIL import Image
+import sys
+import cv2
+import json
+import os
+
+
+def read_img(filename):
+ # convert to RGB for scene flow finalpass data
+ img = np.array(Image.open(filename).convert('RGB')).astype(np.float32)
+ return img
+
+
+def read_disp(filename, subset=False, vkitti2=False, sintel=False,
+ tartanair=False, instereo2k=False, crestereo=False,
+ fallingthings=False,
+ argoverse=False,
+ raw_disp_png=False,
+ ):
+ # Scene Flow dataset
+ if filename.endswith('pfm'):
+ # For finalpass and cleanpass, gt disparity is positive, subset is negative
+ disp = np.ascontiguousarray(_read_pfm(filename)[0])
+ if subset:
+ disp = -disp
+ # VKITTI2 dataset
+ elif vkitti2:
+ disp = _read_vkitti2_disp(filename)
+ # Sintel
+ elif sintel:
+ disp = _read_sintel_disparity(filename)
+ elif tartanair:
+ disp = _read_tartanair_disp(filename)
+ elif instereo2k:
+ disp = _read_instereo2k_disp(filename)
+ elif crestereo:
+ disp = _read_crestereo_disp(filename)
+ elif fallingthings:
+ disp = _read_fallingthings_disp(filename)
+ elif argoverse:
+ disp = _read_argoverse_disp(filename)
+ elif raw_disp_png:
+ disp = np.array(Image.open(filename)).astype(np.float32)
+ # KITTI
+ elif filename.endswith('png'):
+ disp = _read_kitti_disp(filename)
+ elif filename.endswith('npy'):
+ disp = np.load(filename)
+ else:
+ raise Exception('Invalid disparity file format!')
+ return disp # [H, W]
+
+
+def _read_pfm(file):
+ file = open(file, 'rb')
+
+ color = None
+ width = None
+ height = None
+ scale = None
+ endian = None
+
+ header = file.readline().rstrip()
+ if header.decode("ascii") == 'PF':
+ color = True
+ elif header.decode("ascii") == 'Pf':
+ color = False
+ else:
+ raise Exception('Not a PFM file.')
+
+ dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii"))
+ if dim_match:
+ width, height = list(map(int, dim_match.groups()))
+ else:
+ raise Exception('Malformed PFM header.')
+
+ scale = float(file.readline().decode("ascii").rstrip())
+ if scale < 0: # little-endian
+ endian = '<'
+ scale = -scale
+ else:
+ endian = '>' # big-endian
+
+ data = np.fromfile(file, endian + 'f')
+ shape = (height, width, 3) if color else (height, width)
+
+ data = np.reshape(data, shape)
+ data = np.flipud(data)
+ return data, scale
+
+
+def write_pfm(file, image, scale=1):
+ file = open(file, 'wb')
+
+ color = None
+
+ if image.dtype.name != 'float32':
+ raise Exception('Image dtype must be float32.')
+
+ image = np.flipud(image)
+
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
+ color = True
+ elif len(image.shape) == 2 or len(
+ image.shape) == 3 and image.shape[2] == 1: # greyscale
+ color = False
+ else:
+ raise Exception(
+ 'Image must have H x W x 3, H x W x 1 or H x W dimensions.')
+
+ file.write(b'PF\n' if color else b'Pf\n')
+ file.write(b'%d %d\n' % (image.shape[1], image.shape[0]))
+
+ endian = image.dtype.byteorder
+
+ if endian == '<' or endian == '=' and sys.byteorder == 'little':
+ scale = -scale
+
+ file.write(b'%f\n' % scale)
+
+ image.tofile(file)
+
+
+def _read_kitti_disp(filename):
+ depth = np.array(Image.open(filename))
+ depth = depth.astype(np.float32) / 256.
+ return depth
+
+
+def _read_vkitti2_disp(filename):
+ # read depth
+ depth = cv2.imread(filename, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) # in cm
+ depth = (depth / 100).astype(np.float32) # depth clipped to 655.35m for sky
+
+ valid = (depth > 0) & (depth < 655) # depth clipped to 655.35m for sky
+
+ # convert to disparity
+ focal_length = 725.0087 # in pixels
+ baseline = 0.532725 # meter
+
+ disp = baseline * focal_length / depth
+
+ disp[~valid] = 0.000001 # invalid as very small value
+
+ return disp
+
+
+def _read_sintel_disparity(filename):
+ """ Return disparity read from filename. """
+ f_in = np.array(Image.open(filename))
+
+ d_r = f_in[:, :, 0].astype('float32')
+ d_g = f_in[:, :, 1].astype('float32')
+ d_b = f_in[:, :, 2].astype('float32')
+
+ depth = d_r * 4 + d_g / (2 ** 6) + d_b / (2 ** 14)
+ return depth
+
+
+def _read_tartanair_disp(filename):
+ # the infinite distant object such as the sky has a large depth value (e.g. 10000)
+ depth = np.load(filename)
+
+ # change to disparity image
+ disparity = 80.0 / depth
+
+ return disparity
+
+
+def _read_instereo2k_disp(filename):
+ disp = np.array(Image.open(filename))
+ disp = disp.astype(np.float32) / 100.
+ return disp
+
+
+def _read_crestereo_disp(filename):
+ disp = np.array(Image.open(filename))
+ return disp.astype(np.float32) / 32.
+
+
+def _read_fallingthings_disp(filename):
+ depth = np.array(Image.open(filename))
+ camera_file = os.path.join(os.path.dirname(filename), '_camera_settings.json')
+ with open(camera_file, 'r') as f:
+ intrinsics = json.load(f)
+ fx = intrinsics['camera_settings'][0]['intrinsic_settings']['fx']
+ disp = (fx * 6.0 * 100) / depth.astype(np.float32)
+
+ return disp
+
+
+def _read_argoverse_disp(filename):
+ disparity_map = cv2.imread(filename, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
+ return np.float32(disparity_map) / 256.
+
+
+def extract_video(video_name):
+ cap = cv2.VideoCapture(video_name)
+ assert cap.isOpened(), f'Failed to load video file {video_name}'
+ # get video info
+ size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
+ int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
+ fps = cap.get(cv2.CAP_PROP_FPS)
+
+ print('video size (hxw): %dx%d' % (size[1], size[0]))
+ print('fps: %d' % fps)
+
+ imgs = []
+ while cap.isOpened():
+ # get frames
+ flag, img = cap.read()
+ if not flag:
+ break
+ # to rgb format
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ imgs.append(img)
+
+ return imgs, fps
diff --git a/src/custom_controlnet_aux/unimatch/utils/flow_viz.py b/src/custom_controlnet_aux/unimatch/utils/flow_viz.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbe3f139d8fc54478fc1880eb6aa5a286660540a
--- /dev/null
+++ b/src/custom_controlnet_aux/unimatch/utils/flow_viz.py
@@ -0,0 +1,290 @@
+# MIT License
+#
+# Copyright (c) 2018 Tom Runia
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to conditions.
+#
+# Author: Tom Runia
+# Date Created: 2018-08-03
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from PIL import Image
+
+
+def make_colorwheel():
+ '''
+ Generates a color wheel for optical flow visualization as presented in:
+ Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
+ URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
+ According to the C++ source code of Daniel Scharstein
+ According to the Matlab source code of Deqing Sun
+ '''
+
+ RY = 15
+ YG = 6
+ GC = 4
+ CB = 11
+ BM = 13
+ MR = 6
+
+ ncols = RY + YG + GC + CB + BM + MR
+ colorwheel = np.zeros((ncols, 3))
+ col = 0
+
+ # RY
+ colorwheel[0:RY, 0] = 255
+ colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
+ col = col + RY
+ # YG
+ colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
+ colorwheel[col:col + YG, 1] = 255
+ col = col + YG
+ # GC
+ colorwheel[col:col + GC, 1] = 255
+ colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
+ col = col + GC
+ # CB
+ colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
+ colorwheel[col:col + CB, 2] = 255
+ col = col + CB
+ # BM
+ colorwheel[col:col + BM, 2] = 255
+ colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
+ col = col + BM
+ # MR
+ colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
+ colorwheel[col:col + MR, 0] = 255
+ return colorwheel
+
+
+def flow_compute_color(u, v, convert_to_bgr=False):
+ '''
+ Applies the flow color wheel to (possibly clipped) flow components u and v.
+ According to the C++ source code of Daniel Scharstein
+ According to the Matlab source code of Deqing Sun
+ :param u: np.ndarray, input horizontal flow
+ :param v: np.ndarray, input vertical flow
+ :param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB
+ :return:
+ '''
+
+ flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
+
+ colorwheel = make_colorwheel() # shape [55x3]
+ ncols = colorwheel.shape[0]
+
+ rad = np.sqrt(np.square(u) + np.square(v))
+ a = np.arctan2(-v, -u) / np.pi
+
+ fk = (a + 1) / 2 * (ncols - 1) + 1
+ k0 = np.floor(fk).astype(np.int32)
+ k1 = k0 + 1
+ k1[k1 == ncols] = 1
+ f = fk - k0
+
+ for i in range(colorwheel.shape[1]):
+ tmp = colorwheel[:, i]
+ col0 = tmp[k0] / 255.0
+ col1 = tmp[k1] / 255.0
+ col = (1 - f) * col0 + f * col1
+
+ idx = (rad <= 1)
+ col[idx] = 1 - rad[idx] * (1 - col[idx])
+ col[~idx] = col[~idx] * 0.75 # out of range?
+
+ # Note the 2-i => BGR instead of RGB
+ ch_idx = 2 - i if convert_to_bgr else i
+ flow_image[:, :, ch_idx] = np.floor(255 * col)
+
+ return flow_image
+
+
+def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False):
+ '''
+ Expects a two dimensional flow image of shape [H,W,2]
+ According to the C++ source code of Daniel Scharstein
+ According to the Matlab source code of Deqing Sun
+ :param flow_uv: np.ndarray of shape [H,W,2]
+ :param clip_flow: float, maximum clipping value for flow
+ :return:
+ '''
+
+ assert flow_uv.ndim == 3, 'input flow must have three dimensions'
+ assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
+
+ if clip_flow is not None:
+ flow_uv = np.clip(flow_uv, 0, clip_flow)
+
+ u = flow_uv[:, :, 0]
+ v = flow_uv[:, :, 1]
+
+ rad = np.sqrt(np.square(u) + np.square(v))
+ rad_max = np.max(rad)
+
+ epsilon = 1e-5
+ u = u / (rad_max + epsilon)
+ v = v / (rad_max + epsilon)
+
+ return flow_compute_color(u, v, convert_to_bgr)
+
+
+UNKNOWN_FLOW_THRESH = 1e7
+SMALLFLOW = 0.0
+LARGEFLOW = 1e8
+
+
+def make_color_wheel():
+ """
+ Generate color wheel according Middlebury color code
+ :return: Color wheel
+ """
+ RY = 15
+ YG = 6
+ GC = 4
+ CB = 11
+ BM = 13
+ MR = 6
+
+ ncols = RY + YG + GC + CB + BM + MR
+
+ colorwheel = np.zeros([ncols, 3])
+
+ col = 0
+
+ # RY
+ colorwheel[0:RY, 0] = 255
+ colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY))
+ col += RY
+
+ # YG
+ colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG))
+ colorwheel[col:col + YG, 1] = 255
+ col += YG
+
+ # GC
+ colorwheel[col:col + GC, 1] = 255
+ colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC))
+ col += GC
+
+ # CB
+ colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB))
+ colorwheel[col:col + CB, 2] = 255
+ col += CB
+
+ # BM
+ colorwheel[col:col + BM, 2] = 255
+ colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM))
+ col += + BM
+
+ # MR
+ colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
+ colorwheel[col:col + MR, 0] = 255
+
+ return colorwheel
+
+
+def compute_color(u, v):
+ """
+ compute optical flow color map
+ :param u: optical flow horizontal map
+ :param v: optical flow vertical map
+ :return: optical flow in color code
+ """
+ [h, w] = u.shape
+ img = np.zeros([h, w, 3])
+ nanIdx = np.isnan(u) | np.isnan(v)
+ u[nanIdx] = 0
+ v[nanIdx] = 0
+
+ colorwheel = make_color_wheel()
+ ncols = np.size(colorwheel, 0)
+
+ rad = np.sqrt(u ** 2 + v ** 2)
+
+ a = np.arctan2(-v, -u) / np.pi
+
+ fk = (a + 1) / 2 * (ncols - 1) + 1
+
+ k0 = np.floor(fk).astype(int)
+
+ k1 = k0 + 1
+ k1[k1 == ncols + 1] = 1
+ f = fk - k0
+
+ for i in range(0, np.size(colorwheel, 1)):
+ tmp = colorwheel[:, i]
+ col0 = tmp[k0 - 1] / 255
+ col1 = tmp[k1 - 1] / 255
+ col = (1 - f) * col0 + f * col1
+
+ idx = rad <= 1
+ col[idx] = 1 - rad[idx] * (1 - col[idx])
+ notidx = np.logical_not(idx)
+
+ col[notidx] *= 0.75
+ img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx)))
+
+ return img
+
+
+# from https://github.com/gengshan-y/VCN
+def flow_to_image(flow):
+ """
+ Convert flow into middlebury color code image
+ :param flow: optical flow map
+ :return: optical flow image in middlebury color
+ """
+ u = flow[:, :, 0]
+ v = flow[:, :, 1]
+
+ maxu = -999.
+ maxv = -999.
+ minu = 999.
+ minv = 999.
+
+ idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
+ u[idxUnknow] = 0
+ v[idxUnknow] = 0
+
+ maxu = max(maxu, np.max(u))
+ minu = min(minu, np.min(u))
+
+ maxv = max(maxv, np.max(v))
+ minv = min(minv, np.min(v))
+
+ rad = np.sqrt(u ** 2 + v ** 2)
+ maxrad = max(-1, np.max(rad))
+
+ u = u / (maxrad + np.finfo(float).eps)
+ v = v / (maxrad + np.finfo(float).eps)
+
+ img = compute_color(u, v)
+
+ idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
+ img[idx] = 0
+
+ return np.uint8(img)
+
+
+def save_vis_flow_tofile(flow, output_path):
+ vis_flow = flow_to_image(flow)
+ Image.fromarray(vis_flow).save(output_path)
+
+
+def flow_tensor_to_image(flow):
+ """Used for tensorboard visualization"""
+ flow = flow.permute(1, 2, 0) # [H, W, 2]
+ flow = flow.detach().cpu().numpy()
+ flow = flow_to_image(flow) # [H, W, 3]
+ flow = np.transpose(flow, (2, 0, 1)) # [3, H, W]
+
+ return flow
diff --git a/src/custom_controlnet_aux/unimatch/utils/frame_utils.py b/src/custom_controlnet_aux/unimatch/utils/frame_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d10d2ee9b3b2832617ddf66225528af59e995c7d
--- /dev/null
+++ b/src/custom_controlnet_aux/unimatch/utils/frame_utils.py
@@ -0,0 +1,158 @@
+import numpy as np
+from PIL import Image
+from os.path import *
+import re
+import cv2
+
+TAG_CHAR = np.array([202021.25], np.float32)
+
+
+def readFlow(fn):
+ """ Read .flo file in Middlebury format"""
+ # Code adapted from:
+ # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
+
+ # WARNING: this will work on little-endian architectures (eg Intel x86) only!
+ # print 'fn = %s'%(fn)
+ with open(fn, 'rb') as f:
+ magic = np.fromfile(f, np.float32, count=1)
+ if 202021.25 != magic:
+ print('Magic number incorrect. Invalid .flo file')
+ return None
+ else:
+ w = np.fromfile(f, np.int32, count=1)
+ h = np.fromfile(f, np.int32, count=1)
+ # print 'Reading %d x %d flo file\n' % (w, h)
+ data = np.fromfile(f, np.float32, count=2 * int(w) * int(h))
+ # Reshape testdata into 3D array (columns, rows, bands)
+ # The reshape here is for visualization, the original code is (w,h,2)
+ return np.resize(data, (int(h), int(w), 2))
+
+
+def readPFM(file):
+ file = open(file, 'rb')
+
+ color = None
+ width = None
+ height = None
+ scale = None
+ endian = None
+
+ header = file.readline().rstrip()
+ if header == b'PF':
+ color = True
+ elif header == b'Pf':
+ color = False
+ else:
+ raise Exception('Not a PFM file.')
+
+ dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
+ if dim_match:
+ width, height = map(int, dim_match.groups())
+ else:
+ raise Exception('Malformed PFM header.')
+
+ scale = float(file.readline().rstrip())
+ if scale < 0: # little-endian
+ endian = '<'
+ scale = -scale
+ else:
+ endian = '>' # big-endian
+
+ data = np.fromfile(file, endian + 'f')
+ shape = (height, width, 3) if color else (height, width)
+
+ data = np.reshape(data, shape)
+ data = np.flipud(data)
+ return data
+
+
+def writeFlow(filename, uv, v=None):
+ """ Write optical flow to file.
+
+ If v is None, uv is assumed to contain both u and v channels,
+ stacked in depth.
+ Original code by Deqing Sun, adapted from Daniel Scharstein.
+ """
+ nBands = 2
+
+ if v is None:
+ assert (uv.ndim == 3)
+ assert (uv.shape[2] == 2)
+ u = uv[:, :, 0]
+ v = uv[:, :, 1]
+ else:
+ u = uv
+
+ assert (u.shape == v.shape)
+ height, width = u.shape
+ f = open(filename, 'wb')
+ # write the header
+ f.write(TAG_CHAR)
+ np.array(width).astype(np.int32).tofile(f)
+ np.array(height).astype(np.int32).tofile(f)
+ # arrange into matrix form
+ tmp = np.zeros((height, width * nBands))
+ tmp[:, np.arange(width) * 2] = u
+ tmp[:, np.arange(width) * 2 + 1] = v
+ tmp.astype(np.float32).tofile(f)
+ f.close()
+
+
+def readFlowKITTI(filename):
+ flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
+ flow = flow[:, :, ::-1].astype(np.float32)
+ flow, valid = flow[:, :, :2], flow[:, :, 2]
+ flow = (flow - 2 ** 15) / 64.0
+ return flow, valid
+
+
+def readDispKITTI(filename):
+ disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
+ valid = disp > 0.0
+ flow = np.stack([-disp, np.zeros_like(disp)], -1)
+ return flow, valid
+
+
+def writeFlowKITTI(filename, uv):
+ uv = 64.0 * uv + 2 ** 15
+ valid = np.ones([uv.shape[0], uv.shape[1], 1])
+ uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
+ cv2.imwrite(filename, uv[..., ::-1])
+
+
+def read_gen(file_name, pil=False):
+ ext = splitext(file_name)[-1]
+ if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
+ return Image.open(file_name)
+ elif ext == '.bin' or ext == '.raw':
+ return np.load(file_name)
+ elif ext == '.flo':
+ return readFlow(file_name).astype(np.float32)
+ elif ext == '.pfm':
+ flow = readPFM(file_name).astype(np.float32)
+ if len(flow.shape) == 2:
+ return flow
+ else:
+ return flow[:, :, :-1]
+ return []
+
+
+def read_vkitti2_flow(filename):
+ # In R, flow along x-axis normalized by image width and quantized to [0;2^16 – 1]
+ # In G, flow along x-axis normalized by image width and quantized to [0;2^16 – 1]
+ # B = 0 for invalid flow (e.g., sky pixels)
+ bgr = cv2.imread(filename, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
+ h, w, _c = bgr.shape
+ assert bgr.dtype == np.uint16 and _c == 3
+ # b == invalid flow flag == 0 for sky or other invalid flow
+ invalid = bgr[:, :, 0] == 0
+ # g,r == flow_y,x normalized by height,width and scaled to [0;2**16 – 1]
+ out_flow = 2.0 / (2 ** 16 - 1.0) * bgr[:, :, 2:0:-1].astype('f4') - 1 # [H, W, 2]
+ out_flow[..., 0] *= (w - 1)
+ out_flow[..., 1] *= (h - 1)
+
+ out_flow[invalid] = 0.000001 # invalid as very small value to add supervison on the sky
+ valid = (np.logical_or(invalid, ~invalid)).astype(np.float32)
+
+ return out_flow, valid
diff --git a/src/custom_controlnet_aux/unimatch/utils/logger.py b/src/custom_controlnet_aux/unimatch/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..742b23c515932dea67a576cd9169845291f9091b
--- /dev/null
+++ b/src/custom_controlnet_aux/unimatch/utils/logger.py
@@ -0,0 +1,104 @@
+import torch
+
+from utils.flow_viz import flow_tensor_to_image
+from .visualization import viz_depth_tensor
+
+
+class Logger:
+ def __init__(self, lr_scheduler,
+ summary_writer,
+ summary_freq=100,
+ start_step=0,
+ img_mean=None,
+ img_std=None,
+ ):
+ self.lr_scheduler = lr_scheduler
+ self.total_steps = start_step
+ self.running_loss = {}
+ self.summary_writer = summary_writer
+ self.summary_freq = summary_freq
+
+ self.img_mean = img_mean
+ self.img_std = img_std
+
+ def print_training_status(self, mode='train', is_depth=False):
+ if is_depth:
+ print('step: %06d \t loss: %.3f' % (self.total_steps, self.running_loss['total_loss'] / self.summary_freq))
+ else:
+ print('step: %06d \t epe: %.3f' % (self.total_steps, self.running_loss['epe'] / self.summary_freq))
+
+ for k in self.running_loss:
+ self.summary_writer.add_scalar(mode + '/' + k,
+ self.running_loss[k] / self.summary_freq, self.total_steps)
+ self.running_loss[k] = 0.0
+
+ def lr_summary(self):
+ lr = self.lr_scheduler.get_last_lr()[0]
+ self.summary_writer.add_scalar('lr', lr, self.total_steps)
+
+ def add_image_summary(self, img1, img2, flow_preds=None, flow_gt=None, mode='train',
+ is_depth=False,
+ ):
+ if self.total_steps % self.summary_freq == 0:
+ if is_depth:
+ img1 = self.unnormalize_image(img1.detach().cpu()) # [3, H, W], range [0, 1]
+ img2 = self.unnormalize_image(img2.detach().cpu())
+
+ concat = torch.cat((img1, img2), dim=-1) # [3, H, W*2]
+
+ self.summary_writer.add_image(mode + '/img', concat, self.total_steps)
+ else:
+ img_concat = torch.cat((img1[0].detach().cpu(), img2[0].detach().cpu()), dim=-1)
+ img_concat = img_concat.type(torch.uint8) # convert to uint8 to visualize in tensorboard
+
+ flow_pred = flow_tensor_to_image(flow_preds[-1][0])
+ forward_flow_gt = flow_tensor_to_image(flow_gt[0])
+ flow_concat = torch.cat((torch.from_numpy(flow_pred),
+ torch.from_numpy(forward_flow_gt)), dim=-1)
+
+ concat = torch.cat((img_concat, flow_concat), dim=-2)
+
+ self.summary_writer.add_image(mode + '/img_pred_gt', concat, self.total_steps)
+
+ def add_depth_summary(self, depth_pred, depth_gt, mode='train'):
+ # assert depth_pred.dim() == 2 # [H, W]
+ if self.total_steps % self.summary_freq == 0 or 'val' in mode:
+ pred_viz = viz_depth_tensor(depth_pred.detach().cpu()) # [3, H, W]
+ gt_viz = viz_depth_tensor(depth_gt.detach().cpu())
+
+ concat = torch.cat((pred_viz, gt_viz), dim=-1) # [3, H, W*2]
+
+ self.summary_writer.add_image(mode + '/depth_pred_gt', concat, self.total_steps)
+
+ def unnormalize_image(self, img):
+ # img: [3, H, W], used for visualizing image
+ mean = torch.tensor(self.img_mean).view(3, 1, 1).type_as(img)
+ std = torch.tensor(self.img_std).view(3, 1, 1).type_as(img)
+
+ out = img * std + mean
+
+ return out
+
+ def push(self, metrics, mode='train', is_depth=False, ):
+ self.total_steps += 1
+
+ self.lr_summary()
+
+ for key in metrics:
+ if key not in self.running_loss:
+ self.running_loss[key] = 0.0
+
+ self.running_loss[key] += metrics[key]
+
+ if self.total_steps % self.summary_freq == 0:
+ self.print_training_status(mode, is_depth=is_depth)
+ self.running_loss = {}
+
+ def write_dict(self, results):
+ for key in results:
+ tag = key.split('_')[0]
+ tag = tag + '/' + key
+ self.summary_writer.add_scalar(tag, results[key], self.total_steps)
+
+ def close(self):
+ self.summary_writer.close()
diff --git a/src/custom_controlnet_aux/unimatch/utils/misc.py b/src/custom_controlnet_aux/unimatch/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a3f84496c996287f9f46bdefa98d3697d136e6d
--- /dev/null
+++ b/src/custom_controlnet_aux/unimatch/utils/misc.py
@@ -0,0 +1,36 @@
+import os
+import sys
+import json
+
+
+def read_text_lines(filepath):
+ with open(filepath, 'r') as f:
+ lines = f.readlines()
+ lines = [l.rstrip() for l in lines]
+ return lines
+
+
+def check_path(path):
+ if not os.path.exists(path):
+ os.makedirs(path, exist_ok=True) # explicitly set exist_ok when multi-processing
+
+
+def save_command(save_path, filename='command_train.txt'):
+ check_path(save_path)
+ command = sys.argv
+ save_file = os.path.join(save_path, filename)
+ # Save all training commands when resuming training
+ with open(save_file, 'a') as f:
+ f.write(' '.join(command))
+ f.write('\n\n')
+
+
+def save_args(args, filename='args.json'):
+ args_dict = vars(args)
+ check_path(args.checkpoint_dir)
+ save_path = os.path.join(args.checkpoint_dir, filename)
+
+ # save all training args when resuming training
+ with open(save_path, 'a') as f:
+ json.dump(args_dict, f, indent=4, sort_keys=False)
+ f.write('\n\n')
diff --git a/src/custom_controlnet_aux/unimatch/utils/utils.py b/src/custom_controlnet_aux/unimatch/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..73d780fa4e16fc1daf4f940791e1f708adf18232
--- /dev/null
+++ b/src/custom_controlnet_aux/unimatch/utils/utils.py
@@ -0,0 +1,157 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+
+class InputPadder:
+ """ Pads images such that dimensions are divisible by 8 """
+
+ def __init__(self, dims, mode='sintel', padding_factor=8):
+ self.ht, self.wd = dims[-2:]
+ pad_ht = (((self.ht // padding_factor) + 1) * padding_factor - self.ht) % padding_factor
+ pad_wd = (((self.wd // padding_factor) + 1) * padding_factor - self.wd) % padding_factor
+ if mode == 'sintel':
+ self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
+ else:
+ self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
+
+ def pad(self, *inputs):
+ return [F.pad(x, self._pad, mode='replicate') for x in inputs]
+
+ def unpad(self, x):
+ ht, wd = x.shape[-2:]
+ c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
+ return x[..., c[0]:c[1], c[2]:c[3]]
+
+
+def bilinear_sampler(img, coords, mode='bilinear', mask=False, padding_mode='zeros'):
+ """ Wrapper for grid_sample, uses pixel coordinates """
+ if coords.size(-1) != 2: # [B, 2, H, W] -> [B, H, W, 2]
+ coords = coords.permute(0, 2, 3, 1)
+
+ H, W = img.shape[-2:]
+ # H = height if height is not None else img.shape[-2]
+ # W = width if width is not None else img.shape[-1]
+
+ xgrid, ygrid = coords.split([1, 1], dim=-1)
+
+ # To handle H or W equals to 1 by explicitly defining height and width
+ if H == 1:
+ assert ygrid.abs().max() < 1e-8
+ H = 10
+ if W == 1:
+ assert xgrid.abs().max() < 1e-8
+ W = 10
+
+ xgrid = 2 * xgrid / (W - 1) - 1
+ ygrid = 2 * ygrid / (H - 1) - 1
+
+ grid = torch.cat([xgrid, ygrid], dim=-1)
+ img = F.grid_sample(img, grid, mode=mode,
+ padding_mode=padding_mode,
+ align_corners=True)
+
+ if mask:
+ mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
+ return img, mask.squeeze(-1).float()
+
+ return img
+
+
+def coords_grid(batch, ht, wd, normalize=False):
+ if normalize: # [-1, 1]
+ coords = torch.meshgrid(2 * torch.arange(ht) / (ht - 1) - 1,
+ 2 * torch.arange(wd) / (wd - 1) - 1)
+ else:
+ coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
+ coords = torch.stack(coords[::-1], dim=0).float()
+ return coords[None].repeat(batch, 1, 1, 1) # [B, 2, H, W]
+
+
+def coords_grid_np(h, w): # used for accumulating high speed sintel flow testdata
+ coords = np.meshgrid(np.arange(h, dtype=np.float32),
+ np.arange(w, dtype=np.float32), indexing='ij')
+ coords = np.stack(coords[::-1], axis=-1) # [H, W, 2]
+
+ return coords
+
+
+def compute_out_of_boundary_mask(flow, downsample_factor=None):
+ # flow: [B, 2, H, W]
+ assert flow.dim() == 4 and flow.size(1) == 2
+ b, _, h, w = flow.shape
+ init_coords = coords_grid(b, h, w).to(flow.device)
+ corres = init_coords + flow # [B, 2, H, W]
+
+ if downsample_factor is not None:
+ assert w % downsample_factor == 0 and h % downsample_factor == 0
+ # the actual max disp can predict is in the downsampled feature resolution, then upsample
+ max_w = (w // downsample_factor - 1) * downsample_factor
+ max_h = (h // downsample_factor - 1) * downsample_factor
+ # print('max_w: %d, max_h: %d' % (max_w, max_h))
+ else:
+ max_w = w - 1
+ max_h = h - 1
+
+ valid_mask = (corres[:, 0] >= 0) & (corres[:, 0] <= max_w) & (corres[:, 1] >= 0) & (corres[:, 1] <= max_h)
+
+ # in case very large flow
+ flow_mask = (flow[:, 0].abs() <= max_w) & (flow[:, 1].abs() <= max_h)
+
+ valid_mask = valid_mask & flow_mask
+
+ return valid_mask # [B, H, W]
+
+
+def normalize_coords(grid):
+ """Normalize coordinates of image scale to [-1, 1]
+ Args:
+ grid: [B, 2, H, W]
+ """
+ assert grid.size(1) == 2
+ h, w = grid.size()[2:]
+ grid[:, 0, :, :] = 2 * (grid[:, 0, :, :].clone() / (w - 1)) - 1 # x: [-1, 1]
+ grid[:, 1, :, :] = 2 * (grid[:, 1, :, :].clone() / (h - 1)) - 1 # y: [-1, 1]
+ # grid = grid.permute((0, 2, 3, 1)) # [B, H, W, 2]
+ return grid
+
+
+def flow_warp(feature, flow, mask=False, padding_mode='zeros'):
+ b, c, h, w = feature.size()
+ assert flow.size(1) == 2
+
+ grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W]
+
+ return bilinear_sampler(feature, grid, mask=mask, padding_mode=padding_mode)
+
+
+def upflow8(flow, mode='bilinear'):
+ new_size = (8 * flow.shape[2], 8 * flow.shape[3])
+ return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
+
+
+def bilinear_upflow(flow, scale_factor=8):
+ assert flow.size(1) == 2
+ flow = F.interpolate(flow, scale_factor=scale_factor,
+ mode='bilinear', align_corners=True) * scale_factor
+
+ return flow
+
+
+def upsample_flow(flow, img):
+ if flow.size(-1) != img.size(-1):
+ scale_factor = img.size(-1) / flow.size(-1)
+ flow = F.interpolate(flow, size=img.size()[-2:],
+ mode='bilinear', align_corners=True) * scale_factor
+ return flow
+
+
+def count_parameters(model):
+ num = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ return num
+
+
+def set_bn_eval(m):
+ classname = m.__class__.__name__
+ if classname.find('BatchNorm') != -1:
+ m.eval()
diff --git a/src/custom_controlnet_aux/unimatch/utils/visualization.py b/src/custom_controlnet_aux/unimatch/utils/visualization.py
new file mode 100644
index 0000000000000000000000000000000000000000..30fe8ae6ee884d1e59518f17a7a80019f7433468
--- /dev/null
+++ b/src/custom_controlnet_aux/unimatch/utils/visualization.py
@@ -0,0 +1,107 @@
+import torch
+import torch.utils.data
+import numpy as np
+import torchvision.utils as vutils
+import cv2
+from matplotlib.cm import get_cmap
+import matplotlib as mpl
+import matplotlib.cm as cm
+
+
+def vis_disparity(disp):
+ disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0
+ disp_vis = disp_vis.astype("uint8")
+ disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
+
+ return disp_vis
+
+
+def gen_error_colormap():
+ cols = np.array(
+ [[0 / 3.0, 0.1875 / 3.0, 49, 54, 149],
+ [0.1875 / 3.0, 0.375 / 3.0, 69, 117, 180],
+ [0.375 / 3.0, 0.75 / 3.0, 116, 173, 209],
+ [0.75 / 3.0, 1.5 / 3.0, 171, 217, 233],
+ [1.5 / 3.0, 3 / 3.0, 224, 243, 248],
+ [3 / 3.0, 6 / 3.0, 254, 224, 144],
+ [6 / 3.0, 12 / 3.0, 253, 174, 97],
+ [12 / 3.0, 24 / 3.0, 244, 109, 67],
+ [24 / 3.0, 48 / 3.0, 215, 48, 39],
+ [48 / 3.0, np.inf, 165, 0, 38]], dtype=np.float32)
+ cols[:, 2: 5] /= 255.
+ return cols
+
+
+def disp_error_img(D_est_tensor, D_gt_tensor, abs_thres=3., rel_thres=0.05, dilate_radius=1):
+ D_gt_np = D_gt_tensor.detach().cpu().numpy()
+ D_est_np = D_est_tensor.detach().cpu().numpy()
+ B, H, W = D_gt_np.shape
+ # valid mask
+ mask = D_gt_np > 0
+ # error in percentage. When error <= 1, the pixel is valid since <= 3px & 5%
+ error = np.abs(D_gt_np - D_est_np)
+ error[np.logical_not(mask)] = 0
+ error[mask] = np.minimum(error[mask] / abs_thres, (error[mask] / D_gt_np[mask]) / rel_thres)
+ # get colormap
+ cols = gen_error_colormap()
+ # create error image
+ error_image = np.zeros([B, H, W, 3], dtype=np.float32)
+ for i in range(cols.shape[0]):
+ error_image[np.logical_and(error >= cols[i][0], error < cols[i][1])] = cols[i, 2:]
+ # TODO: imdilate
+ # error_image = cv2.imdilate(D_err, strel('disk', dilate_radius));
+ error_image[np.logical_not(mask)] = 0.
+ # show color tag in the top-left cornor of the image
+ for i in range(cols.shape[0]):
+ distance = 20
+ error_image[:, :10, i * distance:(i + 1) * distance, :] = cols[i, 2:]
+
+ return torch.from_numpy(np.ascontiguousarray(error_image.transpose([0, 3, 1, 2])))
+
+
+def save_images(logger, mode_tag, images_dict, global_step):
+ images_dict = tensor2numpy(images_dict)
+ for tag, values in images_dict.items():
+ if not isinstance(values, list) and not isinstance(values, tuple):
+ values = [values]
+ for idx, value in enumerate(values):
+ if len(value.shape) == 3:
+ value = value[:, np.newaxis, :, :]
+ value = value[:1]
+ value = torch.from_numpy(value)
+
+ image_name = '{}/{}'.format(mode_tag, tag)
+ if len(values) > 1:
+ image_name = image_name + "_" + str(idx)
+ logger.add_image(image_name, vutils.make_grid(value, padding=0, nrow=1, normalize=True, scale_each=True),
+ global_step)
+
+
+def tensor2numpy(var_dict):
+ for key, vars in var_dict.items():
+ if isinstance(vars, np.ndarray):
+ var_dict[key] = vars
+ elif isinstance(vars, torch.Tensor):
+ var_dict[key] = vars.data.cpu().numpy()
+ else:
+ raise NotImplementedError("invalid input type for tensor2numpy")
+
+ return var_dict
+
+
+def viz_depth_tensor(disp, return_numpy=False, colormap='plasma'):
+ # visualize inverse depth
+ assert isinstance(disp, torch.Tensor)
+
+ disp = disp.numpy()
+ vmax = np.percentile(disp, 95)
+ normalizer = mpl.colors.Normalize(vmin=disp.min(), vmax=vmax)
+ mapper = cm.ScalarMappable(norm=normalizer, cmap=colormap)
+ colormapped_im = (mapper.to_rgba(disp)[:, :, :3] * 255).astype(np.uint8) # [H, W, 3]
+
+ if return_numpy:
+ return colormapped_im
+
+ viz = torch.from_numpy(colormapped_im).permute(2, 0, 1) # [3, H, W]
+
+ return viz
diff --git a/src/custom_controlnet_aux/util.py b/src/custom_controlnet_aux/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..fee98aea453f06dd4cc057900e5d77dda91bd618
--- /dev/null
+++ b/src/custom_controlnet_aux/util.py
@@ -0,0 +1,350 @@
+import os
+import random
+import tempfile
+import warnings
+from contextlib import suppress
+from pathlib import Path
+
+import cv2
+import numpy as np
+import torch
+from huggingface_hub import constants, hf_hub_download
+from torch.utils.model_zoo import load_url
+from ast import literal_eval
+
+
+HF_MODEL_NAME = "lllyasviel/Annotators"
+DWPOSE_MODEL_NAME = "yzd-v/DWPose"
+BDS_MODEL_NAME = "bdsqlsz/qinglong_controlnet-lllite"
+DENSEPOSE_MODEL_NAME = "LayerNorm/DensePose-TorchScript-with-hint-image"
+MESH_GRAPHORMER_MODEL_NAME = "hr16/ControlNet-HandRefiner-pruned"
+SAM_MODEL_NAME = "dhkim2810/MobileSAM"
+UNIMATCH_MODEL_NAME = "hr16/Unimatch"
+DEPTH_ANYTHING_MODEL_NAME = "LiheYoung/Depth-Anything" #HF Space
+DIFFUSION_EDGE_MODEL_NAME = "hr16/Diffusion-Edge"
+METRIC3D_MODEL_NAME = "JUGGHM/Metric3D"
+
+DEPTH_ANYTHING_V2_MODEL_NAME_DICT = {
+ "depth_anything_v2_vits.pth": "depth-anything/Depth-Anything-V2-Small",
+ "depth_anything_v2_vitb.pth": "depth-anything/Depth-Anything-V2-Base",
+ "depth_anything_v2_vitl.pth": "depth-anything/Depth-Anything-V2-Large",
+ "depth_anything_v2_vitg.pth": "depth-anything/Depth-Anything-V2-Giant",
+ "depth_anything_v2_metric_vkitti_vitl.pth": "depth-anything/Depth-Anything-V2-Metric-VKITTI-Large",
+ "depth_anything_v2_metric_hypersim_vitl.pth": "depth-anything/Depth-Anything-V2-Metric-Hypersim-Large"
+}
+
+temp_dir = tempfile.gettempdir()
+annotator_ckpts_path = os.path.join(Path(__file__).parents[2], 'ckpts')
+USE_SYMLINKS = False
+
+try:
+ annotator_ckpts_path = os.environ['AUX_ANNOTATOR_CKPTS_PATH']
+except:
+ warnings.warn("Custom pressesor model path not set successfully.")
+ pass
+
+try:
+ USE_SYMLINKS = literal_eval(os.environ['AUX_USE_SYMLINKS'])
+except:
+ warnings.warn("USE_SYMLINKS not set successfully. Using default value: False to download models.")
+ pass
+
+try:
+ temp_dir = os.environ['AUX_TEMP_DIR']
+ if len(temp_dir) >= 60:
+ warnings.warn(f"custom temp dir is too long. Using default")
+ temp_dir = tempfile.gettempdir()
+except:
+ warnings.warn(f"custom temp dir not set successfully")
+ pass
+
+here = Path(__file__).parent.resolve()
+
+def HWC3(x):
+ assert x.dtype == np.uint8
+ if x.ndim == 2:
+ x = x[:, :, None]
+ assert x.ndim == 3
+ H, W, C = x.shape
+ assert C == 1 or C == 3 or C == 4
+ if C == 3:
+ return x
+ if C == 1:
+ return np.concatenate([x, x, x], axis=2)
+ if C == 4:
+ color = x[:, :, 0:3].astype(np.float32)
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+ y = color * alpha + 255.0 * (1.0 - alpha)
+ y = y.clip(0, 255).astype(np.uint8)
+ return y
+
+
+def make_noise_disk(H, W, C, F, rng=None):
+ if rng:
+ noise = rng.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
+ else:
+ noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
+ noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
+ noise = noise[F: F + H, F: F + W]
+ noise -= np.min(noise)
+ noise /= np.max(noise)
+ if C == 1:
+ noise = noise[:, :, None]
+ return noise
+
+
+def nms(x, t, s):
+ x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
+
+ f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
+ f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
+ f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
+ f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
+
+ y = np.zeros_like(x)
+
+ for f in [f1, f2, f3, f4]:
+ np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
+
+ z = np.zeros_like(y, dtype=np.uint8)
+ z[y > t] = 255
+ return z
+
+def min_max_norm(x):
+ x -= np.min(x)
+ x /= np.maximum(np.max(x), 1e-5)
+ return x
+
+
+def safe_step(x, step=2):
+ y = x.astype(np.float32) * float(step + 1)
+ y = y.astype(np.int32).astype(np.float32) / float(step)
+ return y
+
+
+def img2mask(img, H, W, low=10, high=90):
+ assert img.ndim == 3 or img.ndim == 2
+ assert img.dtype == np.uint8
+
+ if img.ndim == 3:
+ y = img[:, :, random.randrange(0, img.shape[2])]
+ else:
+ y = img
+
+ y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC)
+
+ if random.uniform(0, 1) < 0.5:
+ y = 255 - y
+
+ return y < np.percentile(y, random.randrange(low, high))
+
+def safer_memory(x):
+ # Fix many MAC/AMD problems
+ return np.ascontiguousarray(x.copy()).copy()
+
+UPSCALE_METHODS = ["INTER_NEAREST", "INTER_LINEAR", "INTER_AREA", "INTER_CUBIC", "INTER_LANCZOS4"]
+def get_upscale_method(method_str):
+ assert method_str in UPSCALE_METHODS, f"Method {method_str} not found in {UPSCALE_METHODS}"
+ return getattr(cv2, method_str)
+
+def pad64(x):
+ return int(np.ceil(float(x) / 64.0) * 64 - x)
+
+#https://github.com/Mikubill/sd-webui-controlnet/blob/main/scripts/processor.py#L17
+#Added upscale_method, mode params
+def resize_image_with_pad(input_image, resolution, upscale_method = "", skip_hwc3=False, mode='edge'):
+ if skip_hwc3:
+ img = input_image
+ else:
+ img = HWC3(input_image)
+ H_raw, W_raw, _ = img.shape
+ if resolution == 0:
+ return img, lambda x: x
+ k = float(resolution) / float(min(H_raw, W_raw))
+ H_target = int(np.round(float(H_raw) * k))
+ W_target = int(np.round(float(W_raw) * k))
+ img = cv2.resize(img, (W_target, H_target), interpolation=get_upscale_method(upscale_method) if k > 1 else cv2.INTER_AREA)
+ H_pad, W_pad = pad64(H_target), pad64(W_target)
+ img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)
+
+ def remove_pad(x):
+ return safer_memory(x[:H_target, :W_target, ...])
+
+ return safer_memory(img_padded), remove_pad
+
+def common_input_validate(input_image, output_type, **kwargs):
+ if "img" in kwargs:
+ warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning)
+ input_image = kwargs.pop("img")
+
+ if "return_pil" in kwargs:
+ warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning)
+ output_type = "pil" if kwargs["return_pil"] else "np"
+
+ if type(output_type) is bool:
+ warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions")
+ if output_type:
+ output_type = "pil"
+
+ if input_image is None:
+ raise ValueError("input_image must be defined.")
+
+ if not isinstance(input_image, np.ndarray):
+ input_image = np.array(input_image, dtype=np.uint8)
+ output_type = output_type or "pil"
+ else:
+ output_type = output_type or "np"
+
+ return (input_image, output_type)
+
+def torch_gc():
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
+
+
+def ade_palette():
+ """ADE20K palette that maps each class to RGB values."""
+ return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
+ [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
+ [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
+ [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
+ [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
+ [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
+ [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
+ [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
+ [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
+ [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
+ [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
+ [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
+ [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
+ [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
+ [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
+ [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
+ [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
+ [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
+ [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
+ [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
+ [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
+ [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
+ [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
+ [102, 255, 0], [92, 0, 255]]
+
+#https://stackoverflow.com/a/44873382
+#Assume that the minimum version of Python ppl use is 3.9
+def sha256sum(file_path):
+ import hashlib
+ h = hashlib.sha256()
+ b = bytearray(128*1024)
+ mv = memoryview(b)
+ with open(file_path, 'rb', buffering=0) as f:
+ while n := f.readinto(mv):
+ h.update(mv[:n])
+ return h.hexdigest()
+
+def check_hash_from_torch_hub(file_path, filename):
+ basename, _ = filename.split('.')
+ _, ref_hash = basename.split('-')
+ curr_hash = sha256sum(file_path)
+ return curr_hash[:len(ref_hash)] == ref_hash
+
+def custom_torch_download(filename, ckpts_dir=annotator_ckpts_path):
+ """Download PyTorch models using PyTorch 2.7's built-in download mechanism."""
+ model_url = "https://download.pytorch.org/models/" + filename
+
+ # Use PyTorch's built-in model downloading with custom cache directory
+ local_dir = os.path.join(ckpts_dir, "torch")
+ if not os.path.exists(local_dir):
+ os.makedirs(local_dir, exist_ok=True)
+
+ model_path = os.path.join(local_dir, filename)
+
+ if not os.path.exists(model_path):
+ print(f"Downloading {filename} from pytorch.org...")
+ try:
+ # Use PyTorch 2.7's load_url which handles caching, progress, and hash checking
+ state_dict = load_url(model_url, model_dir=local_dir, file_name=filename, progress=True, check_hash=True)
+ # The file is already saved by load_url, we just need the path
+ except Exception as e:
+ warnings.warn(f"Download failed with error: {e}")
+ raise
+
+ print(f"model_path is {model_path}")
+ return model_path
+
+def custom_hf_download(pretrained_model_or_path, filename, cache_dir=temp_dir, ckpts_dir=annotator_ckpts_path, subfolder='', use_symlinks=USE_SYMLINKS, repo_type="model"):
+
+ local_dir = os.path.join(ckpts_dir, pretrained_model_or_path)
+ model_path = Path(local_dir).joinpath(*subfolder.split('/'), filename).__str__()
+
+ if len(str(model_path)) >= 255:
+ warnings.warn(f"Path {model_path} is too long, \n please change annotator_ckpts_path in config.yaml")
+
+ if not os.path.exists(model_path):
+ print(f"Failed to find {model_path}.\n Downloading from huggingface.co")
+ print(f"cacher folder is {cache_dir}, you can change it by custom_tmp_path in config.yaml")
+ if use_symlinks:
+ cache_dir_d = constants.HF_HUB_CACHE # use huggingface newer env variables `HF_HUB_CACHE`
+ if cache_dir_d is None:
+ import platform
+ if platform.system() == "Windows":
+ cache_dir_d = Path(os.getenv("USERPROFILE")).joinpath(".cache", "huggingface", "hub").__str__()
+ else:
+ cache_dir_d = os.path.join(os.getenv("HOME"), ".cache", "huggingface", "hub")
+ try:
+ # test_link
+ Path(cache_dir_d).mkdir(parents=True, exist_ok=True)
+ Path(ckpts_dir).mkdir(parents=True, exist_ok=True)
+ (Path(cache_dir_d) / f"linktest_{filename}.txt").touch()
+ # symlink instead of link avoid `invalid cross-device link` error.
+ os.symlink(os.path.join(cache_dir_d, f"linktest_{filename}.txt"), os.path.join(ckpts_dir, f"linktest_{filename}.txt"))
+ print("Using symlinks to download models. \n",\
+ "Make sure you have enough space on your cache folder. \n",\
+ "And do not purge the cache folder after downloading.\n",\
+ "Otherwise, you will have to re-download the models every time you run the script.\n",\
+ "You can use USE_SYMLINKS: False in config.yaml to avoid this behavior.")
+ except:
+ print("Maybe not able to create symlink. Disable using symlinks.")
+ use_symlinks = False
+ cache_dir_d = Path(cache_dir).joinpath("ckpts", pretrained_model_or_path).__str__()
+ finally: # always remove test link files
+ with suppress(FileNotFoundError):
+ os.remove(os.path.join(ckpts_dir, f"linktest_{filename}.txt"))
+ os.remove(os.path.join(cache_dir_d, f"linktest_{filename}.txt"))
+ else:
+ cache_dir_d = os.path.join(cache_dir, "ckpts", pretrained_model_or_path)
+
+ model_path = hf_hub_download(repo_id=pretrained_model_or_path,
+ cache_dir=cache_dir_d,
+ local_dir=local_dir,
+ subfolder=subfolder,
+ filename=filename,
+ local_dir_use_symlinks=use_symlinks,
+ resume_download=True,
+ etag_timeout=100,
+ repo_type=repo_type
+ )
+ if not use_symlinks:
+ try:
+ import shutil
+ shutil.rmtree(os.path.join(cache_dir, "ckpts"))
+ except Exception as e :
+ print(e)
+
+ print(f"model_path is {model_path}")
+
+ return model_path
diff --git a/src/custom_controlnet_aux/zoe/LICENSE b/src/custom_controlnet_aux/zoe/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..7a1e90d007836c327846ce8e5151013b115042ab
--- /dev/null
+++ b/src/custom_controlnet_aux/zoe/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Intelligent Systems Lab Org
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/zoe/__init__.py b/src/custom_controlnet_aux/zoe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e67761de606711828cb1c1e64d132f68c1fe9b42
--- /dev/null
+++ b/src/custom_controlnet_aux/zoe/__init__.py
@@ -0,0 +1,3 @@
+# Modern ZoeDepth implementation using HuggingFace transformers
+
+from .transformers import ZoeDetector, ZoeDepthAnythingDetector
\ No newline at end of file
diff --git a/src/custom_controlnet_aux/zoe/transformers.py b/src/custom_controlnet_aux/zoe/transformers.py
new file mode 100644
index 0000000000000000000000000000000000000000..943ab6b7a56bcc6b39e15234a29ab1e51345c369
--- /dev/null
+++ b/src/custom_controlnet_aux/zoe/transformers.py
@@ -0,0 +1,169 @@
+"""
+ZoeDepth implementation using HuggingFace transformers.
+Uses official Intel models for depth estimation.
+"""
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import pipeline, AutoImageProcessor, ZoeDepthForDepthEstimation
+
+# Local utility functions
+def HWC3(x):
+ assert x.dtype == np.uint8
+ if x.ndim == 2:
+ x = x[:, :, None]
+ assert x.ndim == 3
+ H, W, C = x.shape
+ assert C == 1 or C == 3 or C == 4
+ if C == 3:
+ return x
+ if C == 1:
+ return np.concatenate([x, x, x], axis=2)
+ if C == 4:
+ color = x[:, :, 0:3].astype(np.float32)
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+ y = color * alpha + 255.0 * (1.0 - alpha)
+ y = y.clip(0, 255).astype(np.uint8)
+ return y
+
+def pad64(x):
+ return int(np.ceil(float(x) / 64.0) * 64 - x)
+
+def safer_memory(x):
+ return np.ascontiguousarray(x.copy()).copy()
+
+def resize_image_with_pad(input_image, resolution, upscale_method="INTER_CUBIC", skip_hwc3=False, mode='edge'):
+ import cv2
+ if skip_hwc3:
+ img = input_image
+ else:
+ img = HWC3(input_image)
+ H_raw, W_raw, _ = img.shape
+ if resolution == 0:
+ return img, lambda x: x
+ k = float(resolution) / float(min(H_raw, W_raw))
+ H_target = int(np.round(float(H_raw) * k))
+ W_target = int(np.round(float(W_raw) * k))
+
+ upscale_methods = {"INTER_NEAREST": cv2.INTER_NEAREST, "INTER_LINEAR": cv2.INTER_LINEAR,
+ "INTER_AREA": cv2.INTER_AREA, "INTER_CUBIC": cv2.INTER_CUBIC,
+ "INTER_LANCZOS4": cv2.INTER_LANCZOS4}
+ method = upscale_methods.get(upscale_method, cv2.INTER_CUBIC)
+
+ img = cv2.resize(img, (W_target, H_target), interpolation=method if k > 1 else cv2.INTER_AREA)
+ H_pad, W_pad = pad64(H_target), pad64(W_target)
+ img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)
+
+ def remove_pad(x):
+ return safer_memory(x[:H_target, :W_target, ...])
+
+ return safer_memory(img_padded), remove_pad
+
+def common_input_validate(input_image, output_type, **kwargs):
+ import warnings
+ if "img" in kwargs:
+ warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning)
+ input_image = kwargs.pop("img")
+
+ if "return_pil" in kwargs:
+ warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning)
+ output_type = "pil" if kwargs["return_pil"] else "np"
+
+ if type(output_type) is bool:
+ warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions")
+ if output_type:
+ output_type = "pil"
+
+ if input_image is None:
+ raise ValueError("input_image must be defined.")
+
+ if not isinstance(input_image, np.ndarray):
+ input_image = np.array(input_image, dtype=np.uint8)
+ output_type = output_type or "pil"
+ else:
+ output_type = output_type or "np"
+
+ return (input_image, output_type)
+
+
+class ZoeDetector:
+ """ZoeDepth depth estimation using HuggingFace transformers."""
+
+ def __init__(self, model_name="Intel/zoedepth-nyu-kitti"):
+ """Initialize ZoeDepth with specified model."""
+ self.pipe = pipeline(task="depth-estimation", model=model_name)
+ self.device = "cpu"
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path="Intel/zoedepth-nyu-kitti", filename=None, **kwargs):
+ """Create ZoeDetector from pretrained model."""
+ return cls(model_name=pretrained_model_or_path)
+
+ def to(self, device):
+ """Move model to specified device."""
+ self.pipe.model = self.pipe.model.to(device)
+ self.device = device
+ return self
+
+ def __call__(self, input_image, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
+ """Perform depth estimation on input image."""
+ input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+ input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+
+ if isinstance(input_image, np.ndarray):
+ pil_image = Image.fromarray(input_image)
+ else:
+ pil_image = input_image
+
+ with torch.no_grad():
+ result = self.pipe(pil_image)
+ depth = result["depth"]
+
+ if isinstance(depth, Image.Image):
+ depth_array = np.array(depth, dtype=np.float32)
+ else:
+ depth_array = np.array(depth)
+
+ vmin = np.percentile(depth_array, 2)
+ vmax = np.percentile(depth_array, 85)
+
+ depth_array = depth_array - vmin
+ depth_array = depth_array / (vmax - vmin)
+ depth_array = 1.0 - depth_array
+ depth_image = (depth_array * 255.0).clip(0, 255).astype(np.uint8)
+
+ detected_map = remove_pad(HWC3(depth_image))
+
+ if output_type == "pil":
+ detected_map = Image.fromarray(detected_map)
+
+ return detected_map
+
+
+class ZoeDepthAnythingDetector:
+ """ZoeDepthAnything implementation using HuggingFace transformers."""
+
+ def __init__(self, model_name="Intel/zoedepth-nyu-kitti"):
+ """Initialize ZoeDepthAnything detector."""
+ self.pipe = pipeline(task="depth-estimation", model=model_name)
+ self.device = "cpu"
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_or_path="Intel/zoedepth-nyu-kitti", filename=None, **kwargs):
+ """Create from pretrained model."""
+ return cls(model_name=pretrained_model_or_path)
+
+ def to(self, device):
+ """Move model to specified device."""
+ self.pipe.model = self.pipe.model.to(device)
+ self.device = device
+ return self
+
+ def __call__(self, input_image, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
+ """Perform depth estimation."""
+ detector = ZoeDetector(model_name="Intel/zoedepth-nyu-kitti")
+ detector.pipe = self.pipe
+ detector.device = self.device
+
+ return detector(input_image, detect_resolution, output_type, upscale_method, **kwargs)
\ No newline at end of file
diff --git a/src/custom_manopth/CHANGES.md b/src/custom_manopth/CHANGES.md
new file mode 100644
index 0000000000000000000000000000000000000000..27e7d74595a4b048f0e0aff1f77a7488870a821e
--- /dev/null
+++ b/src/custom_manopth/CHANGES.md
@@ -0,0 +1 @@
+* Chumpy is removed
\ No newline at end of file
diff --git a/src/custom_manopth/LICENSE b/src/custom_manopth/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..e72bfddabc15be5718a7cc061ac10e47741d8219
--- /dev/null
+++ b/src/custom_manopth/LICENSE
@@ -0,0 +1,674 @@
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc.
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU General Public License is a free, copyleft license for
+software and other kinds of works.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+the GNU General Public License is intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users. We, the Free Software Foundation, use the
+GNU General Public License for most of our software; it applies also to
+any other work released this way by its authors. You can apply it to
+your programs, too.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ To protect your rights, we need to prevent others from denying you
+these rights or asking you to surrender the rights. Therefore, you have
+certain responsibilities if you distribute copies of the software, or if
+you modify it: responsibilities to respect the freedom of others.
+
+ For example, if you distribute copies of such a program, whether
+gratis or for a fee, you must pass on to the recipients the same
+freedoms that you received. You must make sure that they, too, receive
+or can get the source code. And you must show them these terms so they
+know their rights.
+
+ Developers that use the GNU GPL protect your rights with two steps:
+(1) assert copyright on the software, and (2) offer you this License
+giving you legal permission to copy, distribute and/or modify it.
+
+ For the developers' and authors' protection, the GPL clearly explains
+that there is no warranty for this free software. For both users' and
+authors' sake, the GPL requires that modified versions be marked as
+changed, so that their problems will not be attributed erroneously to
+authors of previous versions.
+
+ Some devices are designed to deny users access to install or run
+modified versions of the software inside them, although the manufacturer
+can do so. This is fundamentally incompatible with the aim of
+protecting users' freedom to change the software. The systematic
+pattern of such abuse occurs in the area of products for individuals to
+use, which is precisely where it is most unacceptable. Therefore, we
+have designed this version of the GPL to prohibit the practice for those
+products. If such problems arise substantially in other domains, we
+stand ready to extend this provision to those domains in future versions
+of the GPL, as needed to protect the freedom of users.
+
+ Finally, every program is threatened constantly by software patents.
+States should not allow patents to restrict development and use of
+software on general-purpose computers, but in those that do, we wish to
+avoid the special danger that patents applied to a free program could
+make it effectively proprietary. To prevent this, the GPL assures that
+patents cannot be used to render the program non-free.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Use with the GNU Affero General Public License.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU Affero General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the special requirements of the GNU Affero General Public License,
+section 13, concerning interaction through a network will apply to the
+combination as such.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU General Public License from time to time. Such new versions will
+be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+
+ Copyright (C)
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see .
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If the program does terminal interaction, make it output a short
+notice like this when it starts in an interactive mode:
+
+ Copyright (C)
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
+ This is free software, and you are welcome to redistribute it
+ under certain conditions; type `show c' for details.
+
+The hypothetical commands `show w' and `show c' should show the appropriate
+parts of the General Public License. Of course, your program's commands
+might be different; for a GUI interface, you would use an "about box".
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU GPL, see
+.
+
+ The GNU General Public License does not permit incorporating your program
+into proprietary programs. If your program is a subroutine library, you
+may consider it more useful to permit linking proprietary applications with
+the library. If this is what you want to do, use the GNU Lesser General
+Public License instead of this License. But first, please read
+.
\ No newline at end of file
diff --git a/src/custom_manopth/__init__.py b/src/custom_manopth/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e27cf8699e86b13d9d4cb28da10d54c405effe96
--- /dev/null
+++ b/src/custom_manopth/__init__.py
@@ -0,0 +1 @@
+name = 'manopth'
diff --git a/src/custom_manopth/argutils.py b/src/custom_manopth/argutils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e86eb025ad0618e63d730b4f59ee3615118d197
--- /dev/null
+++ b/src/custom_manopth/argutils.py
@@ -0,0 +1,51 @@
+import datetime
+import os
+import pickle
+import subprocess
+import sys
+
+
+def print_args(args):
+ opts = vars(args)
+ print('======= Options ========')
+ for k, v in sorted(opts.items()):
+ print('{}: {}'.format(k, v))
+ print('========================')
+
+
+def save_args(args, save_folder, opt_prefix='opt', verbose=True):
+ opts = vars(args)
+ # Create checkpoint folder
+ if not os.path.exists(save_folder):
+ os.makedirs(save_folder, exist_ok=True)
+
+ # Save options
+ opt_filename = '{}.txt'.format(opt_prefix)
+ opt_path = os.path.join(save_folder, opt_filename)
+ with open(opt_path, 'a') as opt_file:
+ opt_file.write('====== Options ======\n')
+ for k, v in sorted(opts.items()):
+ opt_file.write(
+ '{option}: {value}\n'.format(option=str(k), value=str(v)))
+ opt_file.write('=====================\n')
+ opt_file.write('launched {} at {}\n'.format(
+ str(sys.argv[0]), str(datetime.datetime.now())))
+
+ # Add git info
+ label = subprocess.check_output(["git", "describe",
+ "--always"]).strip()
+ if subprocess.call(
+ ["git", "branch"],
+ stderr=subprocess.STDOUT,
+ stdout=open(os.devnull, 'w')) == 0:
+ opt_file.write('=== Git info ====\n')
+ opt_file.write('{}\n'.format(label))
+ commit = subprocess.check_output(['git', 'rev-parse', 'HEAD'])
+ opt_file.write('commit : {}\n'.format(commit.strip()))
+
+ opt_picklename = '{}.pkl'.format(opt_prefix)
+ opt_picklepath = os.path.join(save_folder, opt_picklename)
+ with open(opt_picklepath, 'wb') as opt_file:
+ pickle.dump(opts, opt_file)
+ if verbose:
+ print('Saved options to {}'.format(opt_path))
diff --git a/src/custom_manopth/demo.py b/src/custom_manopth/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a250b7c4d4e2fd2face4611d87839faf795eecd
--- /dev/null
+++ b/src/custom_manopth/demo.py
@@ -0,0 +1,59 @@
+from matplotlib import pyplot as plt
+from mpl_toolkits.mplot3d import Axes3D
+from mpl_toolkits.mplot3d.art3d import Poly3DCollection
+import numpy as np
+import torch
+
+from custom_manopth.manolayer import ManoLayer
+
+
+def generate_random_hand(batch_size=1, ncomps=6, mano_root='mano/models'):
+ nfull_comps = ncomps + 3 # Add global orientation dims to PCA
+ random_pcapose = torch.rand(batch_size, nfull_comps)
+ mano_layer = ManoLayer(mano_root=mano_root)
+ verts, joints = mano_layer(random_pcapose)
+ return {'verts': verts, 'joints': joints, 'faces': mano_layer.th_faces}
+
+
+def display_hand(hand_info, mano_faces=None, ax=None, alpha=0.2, batch_idx=0, show=True):
+ """
+ Displays hand batch_idx in batch of hand_info, hand_info as returned by
+ generate_random_hand
+ """
+ if ax is None:
+ fig = plt.figure()
+ ax = fig.add_subplot(111, projection='3d')
+ verts, joints = hand_info['verts'][batch_idx], hand_info['joints'][
+ batch_idx]
+ if mano_faces is None:
+ ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.1)
+ else:
+ mesh = Poly3DCollection(verts[mano_faces], alpha=alpha)
+ face_color = (141 / 255, 184 / 255, 226 / 255)
+ edge_color = (50 / 255, 50 / 255, 50 / 255)
+ mesh.set_edgecolor(edge_color)
+ mesh.set_facecolor(face_color)
+ ax.add_collection3d(mesh)
+ ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], color='r')
+ cam_equal_aspect_3d(ax, verts.numpy())
+ if show:
+ plt.show()
+
+
+def cam_equal_aspect_3d(ax, verts, flip_x=False):
+ """
+ Centers view on cuboid containing hand and flips y and z axis
+ and fixes azimuth
+ """
+ extents = np.stack([verts.min(0), verts.max(0)], axis=1)
+ sz = extents[:, 1] - extents[:, 0]
+ centers = np.mean(extents, axis=1)
+ maxsize = max(abs(sz))
+ r = maxsize / 2
+ if flip_x:
+ ax.set_xlim(centers[0] + r, centers[0] - r)
+ else:
+ ax.set_xlim(centers[0] - r, centers[0] + r)
+ # Invert y and z axis
+ ax.set_ylim(centers[1] + r, centers[1] - r)
+ ax.set_zlim(centers[2] + r, centers[2] - r)
diff --git a/src/custom_manopth/manolayer.py b/src/custom_manopth/manolayer.py
new file mode 100644
index 0000000000000000000000000000000000000000..29440d96eb0e41ca7945add89e1df12e274f2b01
--- /dev/null
+++ b/src/custom_manopth/manolayer.py
@@ -0,0 +1,274 @@
+import os
+
+import numpy as np
+import torch
+from torch.nn import Module
+
+from custom_manopth.smpl_handpca_wrapper_HAND_only import ready_arguments
+from custom_manopth import rodrigues_layer, rotproj, rot6d
+from custom_manopth.tensutils import (th_posemap_axisang, th_with_zeros, th_pack,
+ subtract_flat_id, make_list)
+
+
+class ManoLayer(Module):
+ __constants__ = [
+ 'use_pca', 'rot', 'ncomps', 'ncomps', 'kintree_parents', 'check',
+ 'side', 'center_idx', 'joint_rot_mode'
+ ]
+
+ def __init__(self,
+ center_idx=None,
+ flat_hand_mean=True,
+ ncomps=6,
+ side='right',
+ mano_root='mano/models',
+ use_pca=True,
+ root_rot_mode='axisang',
+ joint_rot_mode='axisang',
+ robust_rot=False):
+ """
+ Args:
+ center_idx: index of center joint in our computations,
+ if -1 centers on estimate of palm as middle of base
+ of middle finger and wrist
+ flat_hand_mean: if True, (0, 0, 0, ...) pose coefficients match
+ flat hand, else match average hand pose
+ mano_root: path to MANO pkl files for left and right hand
+ ncomps: number of PCA components form pose space (<45)
+ side: 'right' or 'left'
+ use_pca: Use PCA decomposition for pose space.
+ joint_rot_mode: 'axisang' or 'rotmat', ignored if use_pca
+ """
+ super().__init__()
+
+ self.center_idx = center_idx
+ self.robust_rot = robust_rot
+ if root_rot_mode == 'axisang':
+ self.rot = 3
+ else:
+ self.rot = 6
+ self.flat_hand_mean = flat_hand_mean
+ self.side = side
+ self.use_pca = use_pca
+ self.joint_rot_mode = joint_rot_mode
+ self.root_rot_mode = root_rot_mode
+ if use_pca:
+ self.ncomps = ncomps
+ else:
+ self.ncomps = 45
+
+ if side == 'right':
+ self.mano_path = os.path.join(mano_root, 'MANO_RIGHT.pkl')
+ elif side == 'left':
+ self.mano_path = os.path.join(mano_root, 'MANO_LEFT.pkl')
+
+ smpl_data = ready_arguments(self.mano_path)
+
+ hands_components = smpl_data['hands_components']
+
+ self.smpl_data = smpl_data
+
+ self.register_buffer('th_betas',
+ torch.Tensor(smpl_data['betas']).unsqueeze(0))
+ self.register_buffer('th_shapedirs',
+ torch.Tensor(smpl_data['shapedirs']))
+ self.register_buffer('th_posedirs',
+ torch.Tensor(smpl_data['posedirs']))
+ self.register_buffer(
+ 'th_v_template',
+ torch.Tensor(smpl_data['v_template']).unsqueeze(0))
+ self.register_buffer(
+ 'th_J_regressor',
+ torch.Tensor(np.array(smpl_data['J_regressor'].toarray())))
+ self.register_buffer('th_weights',
+ torch.Tensor(smpl_data['weights']))
+ self.register_buffer('th_faces',
+ torch.Tensor(smpl_data['f'].astype(np.int32)).long())
+
+ # Get hand mean
+ hands_mean = np.zeros(hands_components.shape[1]
+ ) if flat_hand_mean else smpl_data['hands_mean']
+ hands_mean = hands_mean.copy()
+ th_hands_mean = torch.Tensor(hands_mean).unsqueeze(0)
+ if self.use_pca or self.joint_rot_mode == 'axisang':
+ # Save as axis-angle
+ self.register_buffer('th_hands_mean', th_hands_mean)
+ selected_components = hands_components[:ncomps]
+ self.register_buffer('th_comps', torch.Tensor(hands_components))
+ self.register_buffer('th_selected_comps',
+ torch.Tensor(selected_components))
+ else:
+ th_hands_mean_rotmat = rodrigues_layer.batch_rodrigues(
+ th_hands_mean.view(15, 3)).reshape(15, 3, 3)
+ self.register_buffer('th_hands_mean_rotmat', th_hands_mean_rotmat)
+
+ # Kinematic chain params
+ self.kintree_table = smpl_data['kintree_table']
+ parents = list(self.kintree_table[0].tolist())
+ self.kintree_parents = parents
+
+ def forward(self,
+ th_pose_coeffs,
+ th_betas=torch.zeros(1),
+ th_trans=torch.zeros(1),
+ root_palm=torch.Tensor([0]),
+ share_betas=torch.Tensor([0]),
+ ):
+ """
+ Args:
+ th_trans (Tensor (batch_size x ncomps)): if provided, applies trans to joints and vertices
+ th_betas (Tensor (batch_size x 10)): if provided, uses given shape parameters for hand shape
+ else centers on root joint (9th joint)
+ root_palm: return palm as hand root instead of wrist
+ """
+ # if len(th_pose_coeffs) == 0:
+ # return th_pose_coeffs.new_empty(0), th_pose_coeffs.new_empty(0)
+
+ batch_size = th_pose_coeffs.shape[0]
+ # Get axis angle from PCA components and coefficients
+ if self.use_pca or self.joint_rot_mode == 'axisang':
+ # Remove global rot coeffs
+ th_hand_pose_coeffs = th_pose_coeffs[:, self.rot:self.rot +
+ self.ncomps]
+ if self.use_pca:
+ # PCA components --> axis angles
+ th_full_hand_pose = th_hand_pose_coeffs.mm(self.th_selected_comps)
+ else:
+ th_full_hand_pose = th_hand_pose_coeffs
+
+ # Concatenate back global rot
+ th_full_pose = torch.cat([
+ th_pose_coeffs[:, :self.rot],
+ self.th_hands_mean + th_full_hand_pose
+ ], 1)
+ if self.root_rot_mode == 'axisang':
+ # compute rotation matrixes from axis-angle while skipping global rotation
+ th_pose_map, th_rot_map = th_posemap_axisang(th_full_pose)
+ root_rot = th_rot_map[:, :9].view(batch_size, 3, 3)
+ th_rot_map = th_rot_map[:, 9:]
+ th_pose_map = th_pose_map[:, 9:]
+ else:
+ # th_posemap offsets by 3, so add offset or 3 to get to self.rot=6
+ th_pose_map, th_rot_map = th_posemap_axisang(th_full_pose[:, 6:])
+ if self.robust_rot:
+ root_rot = rot6d.robust_compute_rotation_matrix_from_ortho6d(th_full_pose[:, :6])
+ else:
+ root_rot = rot6d.compute_rotation_matrix_from_ortho6d(th_full_pose[:, :6])
+ else:
+ assert th_pose_coeffs.dim() == 4, (
+ 'When not self.use_pca, '
+ 'th_pose_coeffs should have 4 dims, got {}'.format(
+ th_pose_coeffs.dim()))
+ assert th_pose_coeffs.shape[2:4] == (3, 3), (
+ 'When not self.use_pca, th_pose_coeffs have 3x3 matrix for two'
+ 'last dims, got {}'.format(th_pose_coeffs.shape[2:4]))
+ th_pose_rots = rotproj.batch_rotprojs(th_pose_coeffs)
+ th_rot_map = th_pose_rots[:, 1:].view(batch_size, -1)
+ th_pose_map = subtract_flat_id(th_rot_map)
+ root_rot = th_pose_rots[:, 0]
+
+ # Full axis angle representation with root joint
+ if th_betas is None or th_betas.numel() == 1:
+ th_v_shaped = torch.matmul(self.th_shapedirs,
+ self.th_betas.transpose(1, 0)).permute(
+ 2, 0, 1) + self.th_v_template
+ th_j = torch.matmul(self.th_J_regressor, th_v_shaped).repeat(
+ batch_size, 1, 1)
+
+ else:
+ if share_betas:
+ th_betas = th_betas.mean(0, keepdim=True).expand(th_betas.shape[0], 10)
+ th_v_shaped = torch.matmul(self.th_shapedirs,
+ th_betas.transpose(1, 0)).permute(
+ 2, 0, 1) + self.th_v_template
+ th_j = torch.matmul(self.th_J_regressor, th_v_shaped)
+ # th_pose_map should have shape 20x135
+
+ th_v_posed = th_v_shaped + torch.matmul(
+ self.th_posedirs, th_pose_map.transpose(0, 1)).permute(2, 0, 1)
+ # Final T pose with transformation done !
+
+ # Global rigid transformation
+
+ root_j = th_j[:, 0, :].contiguous().view(batch_size, 3, 1)
+ root_trans = th_with_zeros(torch.cat([root_rot, root_j], 2))
+
+ all_rots = th_rot_map.view(th_rot_map.shape[0], 15, 3, 3)
+ lev1_idxs = [1, 4, 7, 10, 13]
+ lev2_idxs = [2, 5, 8, 11, 14]
+ lev3_idxs = [3, 6, 9, 12, 15]
+ lev1_rots = all_rots[:, [idx - 1 for idx in lev1_idxs]]
+ lev2_rots = all_rots[:, [idx - 1 for idx in lev2_idxs]]
+ lev3_rots = all_rots[:, [idx - 1 for idx in lev3_idxs]]
+ lev1_j = th_j[:, lev1_idxs]
+ lev2_j = th_j[:, lev2_idxs]
+ lev3_j = th_j[:, lev3_idxs]
+
+ # From base to tips
+ # Get lev1 results
+ all_transforms = [root_trans.unsqueeze(1)]
+ lev1_j_rel = lev1_j - root_j.transpose(1, 2)
+ lev1_rel_transform_flt = th_with_zeros(torch.cat([lev1_rots, lev1_j_rel.unsqueeze(3)], 3).view(-1, 3, 4))
+ root_trans_flt = root_trans.unsqueeze(1).repeat(1, 5, 1, 1).view(root_trans.shape[0] * 5, 4, 4)
+ lev1_flt = torch.matmul(root_trans_flt, lev1_rel_transform_flt)
+ all_transforms.append(lev1_flt.view(all_rots.shape[0], 5, 4, 4))
+
+ # Get lev2 results
+ lev2_j_rel = lev2_j - lev1_j
+ lev2_rel_transform_flt = th_with_zeros(torch.cat([lev2_rots, lev2_j_rel.unsqueeze(3)], 3).view(-1, 3, 4))
+ lev2_flt = torch.matmul(lev1_flt, lev2_rel_transform_flt)
+ all_transforms.append(lev2_flt.view(all_rots.shape[0], 5, 4, 4))
+
+ # Get lev3 results
+ lev3_j_rel = lev3_j - lev2_j
+ lev3_rel_transform_flt = th_with_zeros(torch.cat([lev3_rots, lev3_j_rel.unsqueeze(3)], 3).view(-1, 3, 4))
+ lev3_flt = torch.matmul(lev2_flt, lev3_rel_transform_flt)
+ all_transforms.append(lev3_flt.view(all_rots.shape[0], 5, 4, 4))
+
+ reorder_idxs = [0, 1, 6, 11, 2, 7, 12, 3, 8, 13, 4, 9, 14, 5, 10, 15]
+ th_results = torch.cat(all_transforms, 1)[:, reorder_idxs]
+ th_results_global = th_results
+
+ joint_js = torch.cat([th_j, th_j.new_zeros(th_j.shape[0], 16, 1)], 2)
+ tmp2 = torch.matmul(th_results, joint_js.unsqueeze(3))
+ th_results2 = (th_results - torch.cat([tmp2.new_zeros(*tmp2.shape[:2], 4, 3), tmp2], 3)).permute(0, 2, 3, 1)
+
+ th_T = torch.matmul(th_results2, self.th_weights.transpose(0, 1))
+
+ th_rest_shape_h = torch.cat([
+ th_v_posed.transpose(2, 1),
+ torch.ones((batch_size, 1, th_v_posed.shape[1]),
+ dtype=th_T.dtype,
+ device=th_T.device),
+ ], 1)
+
+ th_verts = (th_T * th_rest_shape_h.unsqueeze(1)).sum(2).transpose(2, 1)
+ th_verts = th_verts[:, :, :3]
+ th_jtr = th_results_global[:, :, :3, 3]
+ # In addition to MANO reference joints we sample vertices on each finger
+ # to serve as finger tips
+ if self.side == 'right':
+ tips = th_verts[:, [745, 317, 444, 556, 673]]
+ else:
+ tips = th_verts[:, [745, 317, 445, 556, 673]]
+ if bool(root_palm):
+ palm = (th_verts[:, 95] + th_verts[:, 22]).unsqueeze(1) / 2
+ th_jtr = torch.cat([palm, th_jtr[:, 1:]], 1)
+ th_jtr = torch.cat([th_jtr, tips], 1)
+
+ # Reorder joints to match visualization utilities
+ th_jtr = th_jtr[:, [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]]
+
+ if th_trans is None or bool(torch.norm(th_trans) == 0):
+ if self.center_idx is not None:
+ center_joint = th_jtr[:, self.center_idx].unsqueeze(1)
+ th_jtr = th_jtr - center_joint
+ th_verts = th_verts - center_joint
+ else:
+ th_jtr = th_jtr + th_trans.unsqueeze(1)
+ th_verts = th_verts + th_trans.unsqueeze(1)
+
+ # Scale to milimeters
+ th_verts = th_verts * 1000
+ th_jtr = th_jtr * 1000
+ return th_verts, th_jtr
diff --git a/src/custom_manopth/posemapper.py b/src/custom_manopth/posemapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b86ea0a5bfbcf5f2f59a6874fb1bd0d8aa99e6d
--- /dev/null
+++ b/src/custom_manopth/posemapper.py
@@ -0,0 +1,37 @@
+'''
+Copyright 2017 Javier Romero, Dimitrios Tzionas, Michael J Black and the Max Planck Gesellschaft. All rights reserved.
+This software is provided for research purposes only.
+By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license
+
+More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de.
+For comments or questions, please email us at: mano@tue.mpg.de
+
+
+About this file:
+================
+This file defines a wrapper for the loading functions of the MANO model.
+
+Modules included:
+- load_model:
+ loads the MANO model from a given file location (i.e. a .pkl file location),
+ or a dictionary object.
+
+'''
+
+
+import numpy as np
+import cv2
+
+def lrotmin(p):
+ if isinstance(p, np.ndarray):
+ p = p.ravel()[3:]
+ return np.concatenate(
+ [(cv2.Rodrigues(np.array(pp))[0] - np.eye(3)).ravel()
+ for pp in p.reshape((-1, 3))]).ravel()
+
+
+def posemap(s):
+ if s == 'lrotmin':
+ return lrotmin
+ else:
+ raise Exception('Unknown posemapping: %s' % (str(s), ))
\ No newline at end of file
diff --git a/src/custom_manopth/rodrigues_layer.py b/src/custom_manopth/rodrigues_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb7a2565a6c7ccf185446c157bd2599943d1832c
--- /dev/null
+++ b/src/custom_manopth/rodrigues_layer.py
@@ -0,0 +1,89 @@
+"""
+This part reuses code from https://github.com/MandyMo/pytorch_HMR/blob/master/src/util.py
+which is part of a PyTorch port of SMPL.
+Thanks to Zhang Xiong (MandyMo) for making this great code available on github !
+"""
+
+import argparse
+from torch.autograd import gradcheck
+import torch
+from torch.autograd import Variable
+
+from custom_manopth import argutils
+
+
+def quat2mat(quat):
+ """Convert quaternion coefficients to rotation matrix.
+ Args:
+ quat: size = [batch_size, 4] 4 <===>(w, x, y, z)
+ Returns:
+ Rotation matrix corresponding to the quaternion -- size = [batch_size, 3, 3]
+ """
+ norm_quat = quat
+ norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
+ w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:,
+ 2], norm_quat[:,
+ 3]
+
+ batch_size = quat.size(0)
+
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
+ wx, wy, wz = w * x, w * y, w * z
+ xy, xz, yz = x * y, x * z, y * z
+
+ rotMat = torch.stack([
+ w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy,
+ w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz,
+ w2 - x2 - y2 + z2
+ ],
+ dim=1).view(batch_size, 3, 3)
+ return rotMat
+
+
+def batch_rodrigues(axisang):
+ #axisang N x 3
+ axisang_norm = torch.norm(axisang + 1e-8, p=2, dim=1)
+ angle = torch.unsqueeze(axisang_norm, -1)
+ axisang_normalized = torch.div(axisang, angle)
+ angle = angle * 0.5
+ v_cos = torch.cos(angle)
+ v_sin = torch.sin(angle)
+ quat = torch.cat([v_cos, v_sin * axisang_normalized], dim=1)
+ rot_mat = quat2mat(quat)
+ rot_mat = rot_mat.view(rot_mat.shape[0], 9)
+ return rot_mat
+
+
+def th_get_axis_angle(vector):
+ angle = torch.norm(vector, 2, 1)
+ axes = vector / angle.unsqueeze(1)
+ return axes, angle
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--batch_size', default=1, type=int)
+ parser.add_argument('--cuda', action='store_true')
+ args = parser.parse_args()
+
+ argutils.print_args(args)
+
+ n_components = 6
+ rot = 3
+ inputs = torch.rand(args.batch_size, rot)
+ inputs_var = Variable(inputs.double(), requires_grad=True)
+ if args.cuda:
+ inputs = inputs.cuda()
+ # outputs = batch_rodrigues(inputs)
+ test_function = gradcheck(batch_rodrigues, (inputs_var, ))
+ print('batch test passed !')
+
+ inputs = torch.rand(rot)
+ inputs_var = Variable(inputs.double(), requires_grad=True)
+ test_function = gradcheck(th_cv2_rod_sub_id.apply, (inputs_var, ))
+ print('th_cv2_rod test passed')
+
+ inputs = torch.rand(rot)
+ inputs_var = Variable(inputs.double(), requires_grad=True)
+ test_th = gradcheck(th_cv2_rod.apply, (inputs_var, ))
+ print('th_cv2_rod_id test passed !')
diff --git a/src/custom_manopth/rot6d.py b/src/custom_manopth/rot6d.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1d60efbcfadda5f216c0eb9a60b348e248435a0
--- /dev/null
+++ b/src/custom_manopth/rot6d.py
@@ -0,0 +1,71 @@
+import torch
+
+
+def compute_rotation_matrix_from_ortho6d(poses):
+ """
+ Code from
+ https://github.com/papagina/RotationContinuity
+ On the Continuity of Rotation Representations in Neural Networks
+ Zhou et al. CVPR19
+ https://zhouyisjtu.github.io/project_rotation/rotation.html
+ """
+ x_raw = poses[:, 0:3] # batch*3
+ y_raw = poses[:, 3:6] # batch*3
+
+ x = normalize_vector(x_raw) # batch*3
+ z = cross_product(x, y_raw) # batch*3
+ z = normalize_vector(z) # batch*3
+ y = cross_product(z, x) # batch*3
+
+ x = x.view(-1, 3, 1)
+ y = y.view(-1, 3, 1)
+ z = z.view(-1, 3, 1)
+ matrix = torch.cat((x, y, z), 2) # batch*3*3
+ return matrix
+
+def robust_compute_rotation_matrix_from_ortho6d(poses):
+ """
+ Instead of making 2nd vector orthogonal to first
+ create a base that takes into account the two predicted
+ directions equally
+ """
+ x_raw = poses[:, 0:3] # batch*3
+ y_raw = poses[:, 3:6] # batch*3
+
+ x = normalize_vector(x_raw) # batch*3
+ y = normalize_vector(y_raw) # batch*3
+ middle = normalize_vector(x + y)
+ orthmid = normalize_vector(x - y)
+ x = normalize_vector(middle + orthmid)
+ y = normalize_vector(middle - orthmid)
+ # Their scalar product should be small !
+ # assert torch.einsum("ij,ij->i", [x, y]).abs().max() < 0.00001
+ z = normalize_vector(cross_product(x, y))
+
+ x = x.view(-1, 3, 1)
+ y = y.view(-1, 3, 1)
+ z = z.view(-1, 3, 1)
+ matrix = torch.cat((x, y, z), 2) # batch*3*3
+ # Check for reflection in matrix ! If found, flip last vector TODO
+ assert (torch.stack([torch.det(mat) for mat in matrix ])< 0).sum() == 0
+ return matrix
+
+
+def normalize_vector(v):
+ batch = v.shape[0]
+ v_mag = torch.sqrt(v.pow(2).sum(1)) # batch
+ v_mag = torch.max(v_mag, v.new([1e-8]))
+ v_mag = v_mag.view(batch, 1).expand(batch, v.shape[1])
+ v = v/v_mag
+ return v
+
+
+def cross_product(u, v):
+ batch = u.shape[0]
+ i = u[:, 1] * v[:, 2] - u[:, 2] * v[:, 1]
+ j = u[:, 2] * v[:, 0] - u[:, 0] * v[:, 2]
+ k = u[:, 0] * v[:, 1] - u[:, 1] * v[:, 0]
+
+ out = torch.cat((i.view(batch, 1), j.view(batch, 1), k.view(batch, 1)), 1)
+
+ return out
diff --git a/src/custom_manopth/rotproj.py b/src/custom_manopth/rotproj.py
new file mode 100644
index 0000000000000000000000000000000000000000..91a601d5de8117ed9fe1708c2d11e8713dc18011
--- /dev/null
+++ b/src/custom_manopth/rotproj.py
@@ -0,0 +1,21 @@
+import torch
+
+
+def batch_rotprojs(batches_rotmats):
+ proj_rotmats = []
+ for batch_idx, batch_rotmats in enumerate(batches_rotmats):
+ proj_batch_rotmats = []
+ for rot_idx, rotmat in enumerate(batch_rotmats):
+ # GPU implementation of svd is VERY slow
+ # ~ 2 10^-3 per hit vs 5 10^-5 on cpu
+ U, S, V = rotmat.cpu().svd()
+ rotmat = torch.matmul(U, V.transpose(0, 1))
+ orth_det = rotmat.det()
+ # Remove reflection
+ if orth_det < 0:
+ rotmat[:, 2] = -1 * rotmat[:, 2]
+
+ rotmat = rotmat.cuda()
+ proj_batch_rotmats.append(rotmat)
+ proj_rotmats.append(torch.stack(proj_batch_rotmats))
+ return torch.stack(proj_rotmats)
diff --git a/src/custom_manopth/smpl_handpca_wrapper_HAND_only.py b/src/custom_manopth/smpl_handpca_wrapper_HAND_only.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2d157625974deb10f9f77f95de1f1d369ecbf8f
--- /dev/null
+++ b/src/custom_manopth/smpl_handpca_wrapper_HAND_only.py
@@ -0,0 +1,155 @@
+'''
+Copyright 2017 Javier Romero, Dimitrios Tzionas, Michael J Black and the Max Planck Gesellschaft. All rights reserved.
+This software is provided for research purposes only.
+By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license
+
+More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de.
+For comments or questions, please email us at: mano@tue.mpg.de
+
+
+About this file:
+================
+This file defines a wrapper for the loading functions of the MANO model.
+
+Modules included:
+- load_model:
+ loads the MANO model from a given file location (i.e. a .pkl file location),
+ or a dictionary object.
+
+'''
+
+def col(A):
+ return A.reshape((-1, 1))
+
+def MatVecMult(mtx, vec):
+ result = mtx.dot(col(vec.ravel())).ravel()
+ if len(vec.shape) > 1 and vec.shape[1] > 1:
+ result = result.reshape((-1, vec.shape[1]))
+ return result
+
+def ready_arguments(fname_or_dict, posekey4vposed='pose'):
+ import numpy as np
+ import pickle
+ from custom_manopth.posemapper import posemap
+
+ if not isinstance(fname_or_dict, dict):
+ dd = pickle.load(open(fname_or_dict, 'rb'), encoding='latin1')
+ # dd = pickle.load(open(fname_or_dict, 'rb'))
+ else:
+ dd = fname_or_dict
+
+ want_shapemodel = 'shapedirs' in dd
+ nposeparms = dd['kintree_table'].shape[1] * 3
+
+ if 'trans' not in dd:
+ dd['trans'] = np.zeros(3)
+ if 'pose' not in dd:
+ dd['pose'] = np.zeros(nposeparms)
+ if 'shapedirs' in dd and 'betas' not in dd:
+ dd['betas'] = np.zeros(dd['shapedirs'].shape[-1])
+
+ for s in [
+ 'v_template', 'weights', 'posedirs', 'pose', 'trans', 'shapedirs',
+ 'betas', 'J'
+ ]:
+ if (s in dd) and not hasattr(dd[s], 'dterms'):
+ dd[s] = np.array(dd[s])
+
+ assert (posekey4vposed in dd)
+ if want_shapemodel:
+ dd['v_shaped'] = dd['shapedirs'].dot(dd['betas']) + dd['v_template']
+ v_shaped = dd['v_shaped']
+ J_tmpx = MatVecMult(dd['J_regressor'], v_shaped[:, 0])
+ J_tmpy = MatVecMult(dd['J_regressor'], v_shaped[:, 1])
+ J_tmpz = MatVecMult(dd['J_regressor'], v_shaped[:, 2])
+ dd['J'] = np.vstack((J_tmpx, J_tmpy, J_tmpz)).T
+ pose_map_res = posemap(dd['bs_type'])(dd[posekey4vposed])
+ dd['v_posed'] = v_shaped + dd['posedirs'].dot(pose_map_res)
+ else:
+ pose_map_res = posemap(dd['bs_type'])(dd[posekey4vposed])
+ dd_add = dd['posedirs'].dot(pose_map_res)
+ dd['v_posed'] = dd['v_template'] + dd_add
+
+ return dd
+
+
+def load_model(fname_or_dict, ncomps=6, flat_hand_mean=False, v_template=None):
+ ''' This model loads the fully articulable HAND SMPL model,
+ and replaces the pose DOFS by ncomps from PCA'''
+
+ from custom_manopth.verts import verts_core
+ import numpy as np
+ import pickle
+ import scipy.sparse as sp
+ np.random.seed(1)
+
+ if not isinstance(fname_or_dict, dict):
+ smpl_data = pickle.load(open(fname_or_dict, 'rb'), encoding='latin1')
+ # smpl_data = pickle.load(open(fname_or_dict, 'rb'))
+ else:
+ smpl_data = fname_or_dict
+
+ rot = 3 # for global orientation!!!
+
+ hands_components = smpl_data['hands_components']
+ hands_mean = np.zeros(hands_components.shape[
+ 1]) if flat_hand_mean else smpl_data['hands_mean']
+ hands_coeffs = smpl_data['hands_coeffs'][:, :ncomps]
+
+ selected_components = np.vstack((hands_components[:ncomps]))
+ hands_mean = hands_mean.copy()
+
+ pose_coeffs = np.zeros(rot + selected_components.shape[0])
+ full_hand_pose = pose_coeffs[rot:(rot + ncomps)].dot(selected_components)
+
+ smpl_data['fullpose'] = np.concatenate((pose_coeffs[:rot],
+ hands_mean + full_hand_pose))
+ smpl_data['pose'] = pose_coeffs
+
+ Jreg = smpl_data['J_regressor']
+ if not sp.issparse(Jreg):
+ smpl_data['J_regressor'] = (sp.csc_matrix(
+ (Jreg.data, (Jreg.row, Jreg.col)), shape=Jreg.shape))
+
+ # slightly modify ready_arguments to make sure that it uses the fullpose
+ # (which will NOT be pose) for the computation of posedirs
+ dd = ready_arguments(smpl_data, posekey4vposed='fullpose')
+
+ # create the smpl formula with the fullpose,
+ # but expose the PCA coefficients as smpl.pose for compatibility
+ args = {
+ 'pose': dd['fullpose'],
+ 'v': dd['v_posed'],
+ 'J': dd['J'],
+ 'weights': dd['weights'],
+ 'kintree_table': dd['kintree_table'],
+ 'xp': np,
+ 'want_Jtr': True,
+ 'bs_style': dd['bs_style'],
+ }
+
+ result_previous, meta = verts_core(**args)
+
+ result = result_previous + dd['trans'].reshape((1, 3))
+ result.no_translation = result_previous
+
+ if meta is not None:
+ for field in ['Jtr', 'A', 'A_global', 'A_weighted']:
+ if (hasattr(meta, field)):
+ setattr(result, field, getattr(meta, field))
+
+ setattr(result, 'Jtr', meta)
+ if hasattr(result, 'Jtr'):
+ result.J_transformed = result.Jtr + dd['trans'].reshape((1, 3))
+
+ for k, v in dd.items():
+ setattr(result, k, v)
+
+ if v_template is not None:
+ result.v_template[:] = v_template
+
+ return result
+
+
+if __name__ == '__main__':
+ load_model()
\ No newline at end of file
diff --git a/src/custom_manopth/tensutils.py b/src/custom_manopth/tensutils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce627beb5b4bd607a0e7ecb2d3a46814199e7b89
--- /dev/null
+++ b/src/custom_manopth/tensutils.py
@@ -0,0 +1,47 @@
+import torch
+
+from custom_manopth import rodrigues_layer
+
+
+def th_posemap_axisang(pose_vectors):
+ rot_nb = int(pose_vectors.shape[1] / 3)
+ pose_vec_reshaped = pose_vectors.contiguous().view(-1, 3)
+ rot_mats = rodrigues_layer.batch_rodrigues(pose_vec_reshaped)
+ rot_mats = rot_mats.view(pose_vectors.shape[0], rot_nb * 9)
+ pose_maps = subtract_flat_id(rot_mats)
+ return pose_maps, rot_mats
+
+
+def th_with_zeros(tensor):
+ batch_size = tensor.shape[0]
+ padding = torch.tensor([0.0, 0.0, 0.0, 1.0], device = tensor.device, dtype = tensor.dtype)
+ padding.requires_grad = False
+
+ concat_list = [tensor, padding.view(1, 1, 4).repeat(batch_size, 1, 1)]
+ cat_res = torch.cat(concat_list, 1)
+ return cat_res
+
+
+def th_pack(tensor):
+ batch_size = tensor.shape[0]
+ padding = tensor.new_zeros((batch_size, 4, 3))
+ padding.requires_grad = False
+ pack_list = [padding, tensor]
+ pack_res = torch.cat(pack_list, 2)
+ return pack_res
+
+
+def subtract_flat_id(rot_mats):
+ # Subtracts identity as a flattened tensor
+ rot_nb = int(rot_mats.shape[1] / 9)
+ id_flat = torch.eye(
+ 3, dtype=rot_mats.dtype, device=rot_mats.device).view(1, 9).repeat(
+ rot_mats.shape[0], rot_nb)
+ # id_flat.requires_grad = False
+ results = rot_mats - id_flat
+ return results
+
+
+def make_list(tensor):
+ # type: (List[int]) -> List[int]
+ return tensor
diff --git a/src/custom_manopth/verts.py b/src/custom_manopth/verts.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a9e5c320544a50b1c3c2d35dce0cb9967ed0ee9
--- /dev/null
+++ b/src/custom_manopth/verts.py
@@ -0,0 +1,117 @@
+'''
+Copyright 2017 Javier Romero, Dimitrios Tzionas, Michael J Black and the Max Planck Gesellschaft. All rights reserved.
+This software is provided for research purposes only.
+By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license
+
+More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de.
+For comments or questions, please email us at: mano@tue.mpg.de
+
+
+About this file:
+================
+This file defines a wrapper for the loading functions of the MANO model.
+
+Modules included:
+- load_model:
+ loads the MANO model from a given file location (i.e. a .pkl file location),
+ or a dictionary object.
+
+'''
+
+
+import numpy as np
+import mano.webuser.lbs as lbs
+from mano.webuser.posemapper import posemap
+import scipy.sparse as sp
+
+
+def ischumpy(x):
+ return hasattr(x, 'dterms')
+
+
+def verts_decorated(trans,
+ pose,
+ v_template,
+ J_regressor,
+ weights,
+ kintree_table,
+ bs_style,
+ f,
+ bs_type=None,
+ posedirs=None,
+ betas=None,
+ shapedirs=None,
+ want_Jtr=False):
+
+ for which in [
+ trans, pose, v_template, weights, posedirs, betas, shapedirs
+ ]:
+ if which is not None:
+ assert ischumpy(which)
+
+ v = v_template
+
+ if shapedirs is not None:
+ if betas is None:
+ betas = np.zeros(shapedirs.shape[-1])
+ v_shaped = v + shapedirs.dot(betas)
+ else:
+ v_shaped = v
+
+ if posedirs is not None:
+ v_posed = v_shaped + posedirs.dot(posemap(bs_type)(pose))
+ else:
+ v_posed = v_shaped
+
+ v = v_posed
+
+ if sp.issparse(J_regressor):
+ J_tmpx = np.matmul(J_regressor, v_shaped[:, 0])
+ J_tmpy = np.matmul(J_regressor, v_shaped[:, 1])
+ J_tmpz = np.matmul(J_regressor, v_shaped[:, 2])
+ J = np.vstack((J_tmpx, J_tmpy, J_tmpz)).T
+ else:
+ assert (ischumpy(J))
+
+ assert (bs_style == 'lbs')
+ result, Jtr = lbs.verts_core(
+ pose, v, J, weights, kintree_table, want_Jtr=True, xp=np)
+
+ tr = trans.reshape((1, 3))
+ result = result + tr
+ Jtr = Jtr + tr
+
+ result.trans = trans
+ result.f = f
+ result.pose = pose
+ result.v_template = v_template
+ result.J = J
+ result.J_regressor = J_regressor
+ result.weights = weights
+ result.kintree_table = kintree_table
+ result.bs_style = bs_style
+ result.bs_type = bs_type
+ if posedirs is not None:
+ result.posedirs = posedirs
+ result.v_posed = v_posed
+ if shapedirs is not None:
+ result.shapedirs = shapedirs
+ result.betas = betas
+ result.v_shaped = v_shaped
+ if want_Jtr:
+ result.J_transformed = Jtr
+ return result
+
+
+def verts_core(pose,
+ v,
+ J,
+ weights,
+ kintree_table,
+ bs_style,
+ want_Jtr=False,
+ xp=np):
+
+ assert (bs_style == 'lbs')
+ result = lbs.verts_core(pose, v, J, weights, kintree_table, want_Jtr, xp)
+ return result
\ No newline at end of file
diff --git a/src/custom_mesh_graphormer/__init__.py b/src/custom_mesh_graphormer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dc1f76bc69e3f559bee6253b24fc93acee9e1f9
--- /dev/null
+++ b/src/custom_mesh_graphormer/__init__.py
@@ -0,0 +1 @@
+__version__ = "0.1.0"
diff --git a/src/custom_mesh_graphormer/datasets/__init__.py b/src/custom_mesh_graphormer/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/src/custom_mesh_graphormer/datasets/__init__.py
@@ -0,0 +1 @@
+
diff --git a/src/custom_mesh_graphormer/datasets/build.py b/src/custom_mesh_graphormer/datasets/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..e91c283074eb264c8780568108a7d58435a4b19f
--- /dev/null
+++ b/src/custom_mesh_graphormer/datasets/build.py
@@ -0,0 +1,147 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+"""
+
+
+import os.path as op
+import torch
+import logging
+import code
+from custom_mesh_graphormer.utils.comm import get_world_size
+from custom_mesh_graphormer.datasets.human_mesh_tsv import (MeshTSVDataset, MeshTSVYamlDataset)
+from custom_mesh_graphormer.datasets.hand_mesh_tsv import (HandMeshTSVDataset, HandMeshTSVYamlDataset)
+
+
+def build_dataset(yaml_file, args, is_train=True, scale_factor=1):
+ print(yaml_file)
+ if not op.isfile(yaml_file):
+ yaml_file = op.join(args.data_dir, yaml_file)
+ # code.interact(local=locals())
+ assert op.isfile(yaml_file)
+ return MeshTSVYamlDataset(yaml_file, is_train, False, scale_factor)
+
+
+class IterationBasedBatchSampler(torch.utils.data.sampler.BatchSampler):
+ """
+ Wraps a BatchSampler, resampling from it until
+ a specified number of iterations have been sampled
+ """
+
+ def __init__(self, batch_sampler, num_iterations, start_iter=0):
+ self.batch_sampler = batch_sampler
+ self.num_iterations = num_iterations
+ self.start_iter = start_iter
+
+ def __iter__(self):
+ iteration = self.start_iter
+ while iteration <= self.num_iterations:
+ # if the underlying sampler has a set_epoch method, like
+ # DistributedSampler, used for making each process see
+ # a different split of the dataset, then set it
+ if hasattr(self.batch_sampler.sampler, "set_epoch"):
+ self.batch_sampler.sampler.set_epoch(iteration)
+ for batch in self.batch_sampler:
+ iteration += 1
+ if iteration > self.num_iterations:
+ break
+ yield batch
+
+ def __len__(self):
+ return self.num_iterations
+
+
+def make_batch_data_sampler(sampler, images_per_gpu, num_iters=None, start_iter=0):
+ batch_sampler = torch.utils.data.sampler.BatchSampler(
+ sampler, images_per_gpu, drop_last=False
+ )
+ if num_iters is not None and num_iters >= 0:
+ batch_sampler = IterationBasedBatchSampler(
+ batch_sampler, num_iters, start_iter
+ )
+ return batch_sampler
+
+
+def make_data_sampler(dataset, shuffle, distributed):
+ if distributed:
+ return torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
+ if shuffle:
+ sampler = torch.utils.data.sampler.RandomSampler(dataset)
+ else:
+ sampler = torch.utils.data.sampler.SequentialSampler(dataset)
+ return sampler
+
+
+def make_data_loader(args, yaml_file, is_distributed=True,
+ is_train=True, start_iter=0, scale_factor=1):
+
+ dataset = build_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor)
+ logger = logging.getLogger(__name__)
+ if is_train==True:
+ shuffle = True
+ images_per_gpu = args.per_gpu_train_batch_size
+ images_per_batch = images_per_gpu * get_world_size()
+ iters_per_batch = len(dataset) // images_per_batch
+ num_iters = iters_per_batch * args.num_train_epochs
+ logger.info("Train with {} images per GPU.".format(images_per_gpu))
+ logger.info("Total batch size {}".format(images_per_batch))
+ logger.info("Total training steps {}".format(num_iters))
+ else:
+ shuffle = False
+ images_per_gpu = args.per_gpu_eval_batch_size
+ num_iters = None
+ start_iter = 0
+
+ sampler = make_data_sampler(dataset, shuffle, is_distributed)
+ batch_sampler = make_batch_data_sampler(
+ sampler, images_per_gpu, num_iters, start_iter
+ )
+ data_loader = torch.utils.data.DataLoader(
+ dataset, num_workers=args.num_workers, batch_sampler=batch_sampler,
+ pin_memory=True,
+ )
+ return data_loader
+
+
+#==============================================================================================
+
+def build_hand_dataset(yaml_file, args, is_train=True, scale_factor=1):
+ print(yaml_file)
+ if not op.isfile(yaml_file):
+ yaml_file = op.join(args.data_dir, yaml_file)
+ # code.interact(local=locals())
+ assert op.isfile(yaml_file)
+ return HandMeshTSVYamlDataset(args, yaml_file, is_train, False, scale_factor)
+
+
+def make_hand_data_loader(args, yaml_file, is_distributed=True,
+ is_train=True, start_iter=0, scale_factor=1):
+
+ dataset = build_hand_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor)
+ logger = logging.getLogger(__name__)
+ if is_train==True:
+ shuffle = True
+ images_per_gpu = args.per_gpu_train_batch_size
+ images_per_batch = images_per_gpu * get_world_size()
+ iters_per_batch = len(dataset) // images_per_batch
+ num_iters = iters_per_batch * args.num_train_epochs
+ logger.info("Train with {} images per GPU.".format(images_per_gpu))
+ logger.info("Total batch size {}".format(images_per_batch))
+ logger.info("Total training steps {}".format(num_iters))
+ else:
+ shuffle = False
+ images_per_gpu = args.per_gpu_eval_batch_size
+ num_iters = None
+ start_iter = 0
+
+ sampler = make_data_sampler(dataset, shuffle, is_distributed)
+ batch_sampler = make_batch_data_sampler(
+ sampler, images_per_gpu, num_iters, start_iter
+ )
+ data_loader = torch.utils.data.DataLoader(
+ dataset, num_workers=args.num_workers, batch_sampler=batch_sampler,
+ pin_memory=True,
+ )
+ return data_loader
+
diff --git a/src/custom_mesh_graphormer/datasets/hand_mesh_tsv.py b/src/custom_mesh_graphormer/datasets/hand_mesh_tsv.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4a839200a50aa353689a022e215e2013890a66c
--- /dev/null
+++ b/src/custom_mesh_graphormer/datasets/hand_mesh_tsv.py
@@ -0,0 +1,334 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+"""
+
+
+import cv2
+import math
+import json
+from PIL import Image
+import os.path as op
+import numpy as np
+import code
+
+from custom_mesh_graphormer.utils.tsv_file import TSVFile, CompositeTSVFile
+from custom_mesh_graphormer.utils.tsv_file_ops import load_linelist_file, load_from_yaml_file, find_file_path_in_yaml
+from custom_mesh_graphormer.utils.image_ops import img_from_base64, crop, flip_img, flip_pose, flip_kp, transform, rot_aa
+import torch
+import torchvision.transforms as transforms
+
+
+class HandMeshTSVDataset(object):
+ def __init__(self, args, img_file, label_file=None, hw_file=None,
+ linelist_file=None, is_train=True, cv2_output=False, scale_factor=1):
+
+ self.args = args
+ self.img_file = img_file
+ self.label_file = label_file
+ self.hw_file = hw_file
+ self.linelist_file = linelist_file
+ self.img_tsv = self.get_tsv_file(img_file)
+ self.label_tsv = None if label_file is None else self.get_tsv_file(label_file)
+ self.hw_tsv = None if hw_file is None else self.get_tsv_file(hw_file)
+
+ if self.is_composite:
+ assert op.isfile(self.linelist_file)
+ self.line_list = [i for i in range(self.hw_tsv.num_rows())]
+ else:
+ self.line_list = load_linelist_file(linelist_file)
+
+ self.cv2_output = cv2_output
+ self.normalize_img = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ self.is_train = is_train
+ self.scale_factor = 0.25 # rescale bounding boxes by a factor of [1-options.scale_factor,1+options.scale_factor]
+ self.noise_factor = 0.4
+ self.rot_factor = 90 # Random rotation in the range [-rot_factor, rot_factor]
+ self.img_res = 224
+ self.image_keys = self.prepare_image_keys()
+ self.joints_definition = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1',
+ 'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4')
+ self.root_index = self.joints_definition.index('Wrist')
+
+ def get_tsv_file(self, tsv_file):
+ if tsv_file:
+ if self.is_composite:
+ return CompositeTSVFile(tsv_file, self.linelist_file,
+ root=self.root)
+ tsv_path = find_file_path_in_yaml(tsv_file, self.root)
+ return TSVFile(tsv_path)
+
+ def get_valid_tsv(self):
+ # sorted by file size
+ if self.hw_tsv:
+ return self.hw_tsv
+ if self.label_tsv:
+ return self.label_tsv
+
+ def prepare_image_keys(self):
+ tsv = self.get_valid_tsv()
+ return [tsv.get_key(i) for i in range(tsv.num_rows())]
+
+ def prepare_image_key_to_index(self):
+ tsv = self.get_valid_tsv()
+ return {tsv.get_key(i) : i for i in range(tsv.num_rows())}
+
+
+ def augm_params(self):
+ """Get augmentation parameters."""
+ flip = 0 # flipping
+ pn = np.ones(3) # per channel pixel-noise
+
+ if self.args.multiscale_inference == False:
+ rot = 0 # rotation
+ sc = 1.0 # scaling
+ elif self.args.multiscale_inference == True:
+ rot = self.args.rot
+ sc = self.args.sc
+
+ if self.is_train:
+ sc = 1.0
+ # Each channel is multiplied with a number
+ # in the area [1-opt.noiseFactor,1+opt.noiseFactor]
+ pn = np.random.uniform(1-self.noise_factor, 1+self.noise_factor, 3)
+
+ # The rotation is a number in the area [-2*rotFactor, 2*rotFactor]
+ rot = min(2*self.rot_factor,
+ max(-2*self.rot_factor, np.random.randn()*self.rot_factor))
+
+ # The scale is multiplied with a number
+ # in the area [1-scaleFactor,1+scaleFactor]
+ sc = min(1+self.scale_factor,
+ max(1-self.scale_factor, np.random.randn()*self.scale_factor+1))
+ # but it is zero with probability 3/5
+ if np.random.uniform() <= 0.6:
+ rot = 0
+
+ return flip, pn, rot, sc
+
+ def rgb_processing(self, rgb_img, center, scale, rot, flip, pn):
+ """Process rgb image and do augmentation."""
+ rgb_img = crop(rgb_img, center, scale,
+ [self.img_res, self.img_res], rot=rot)
+ # flip the image
+ if flip:
+ rgb_img = flip_img(rgb_img)
+ # in the rgb image we add pixel noise in a channel-wise manner
+ rgb_img[:,:,0] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,0]*pn[0]))
+ rgb_img[:,:,1] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,1]*pn[1]))
+ rgb_img[:,:,2] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,2]*pn[2]))
+ # (3,224,224),float,[0,1]
+ rgb_img = np.transpose(rgb_img.astype('float32'),(2,0,1))/255.0
+ return rgb_img
+
+ def j2d_processing(self, kp, center, scale, r, f):
+ """Process gt 2D keypoints and apply all augmentation transforms."""
+ nparts = kp.shape[0]
+ for i in range(nparts):
+ kp[i,0:2] = transform(kp[i,0:2]+1, center, scale,
+ [self.img_res, self.img_res], rot=r)
+ # convert to normalized coordinates
+ kp[:,:-1] = 2.*kp[:,:-1]/self.img_res - 1.
+ # flip the x coordinates
+ if f:
+ kp = flip_kp(kp)
+ kp = kp.astype('float32')
+ return kp
+
+
+ def j3d_processing(self, S, r, f):
+ """Process gt 3D keypoints and apply all augmentation transforms."""
+ # in-plane rotation
+ rot_mat = np.eye(3)
+ if not r == 0:
+ rot_rad = -r * np.pi / 180
+ sn,cs = np.sin(rot_rad), np.cos(rot_rad)
+ rot_mat[0,:2] = [cs, -sn]
+ rot_mat[1,:2] = [sn, cs]
+ S[:, :-1] = np.einsum('ij,kj->ki', rot_mat, S[:, :-1])
+ # flip the x coordinates
+ if f:
+ S = flip_kp(S)
+ S = S.astype('float32')
+ return S
+
+ def pose_processing(self, pose, r, f):
+ """Process SMPL theta parameters and apply all augmentation transforms."""
+ # rotation or the pose parameters
+ pose = pose.astype('float32')
+ pose[:3] = rot_aa(pose[:3], r)
+ # flip the pose parameters
+ if f:
+ pose = flip_pose(pose)
+ # (72),float
+ pose = pose.astype('float32')
+ return pose
+
+ def get_line_no(self, idx):
+ return idx if self.line_list is None else self.line_list[idx]
+
+ def get_image(self, idx):
+ line_no = self.get_line_no(idx)
+ row = self.img_tsv[line_no]
+ # use -1 to support old format with multiple columns.
+ cv2_im = img_from_base64(row[-1])
+ if self.cv2_output:
+ return cv2_im.astype(np.float32, copy=True)
+ cv2_im = cv2.cvtColor(cv2_im, cv2.COLOR_BGR2RGB)
+ return cv2_im
+
+ def get_annotations(self, idx):
+ line_no = self.get_line_no(idx)
+ if self.label_tsv is not None:
+ row = self.label_tsv[line_no]
+ annotations = json.loads(row[1])
+ return annotations
+ else:
+ return []
+
+ def get_target_from_annotations(self, annotations, img_size, idx):
+ # This function will be overwritten by each dataset to
+ # decode the labels to specific formats for each task.
+ return annotations
+
+ def get_img_info(self, idx):
+ if self.hw_tsv is not None:
+ line_no = self.get_line_no(idx)
+ row = self.hw_tsv[line_no]
+ try:
+ # json string format with "height" and "width" being the keys
+ return json.loads(row[1])[0]
+ except ValueError:
+ # list of strings representing height and width in order
+ hw_str = row[1].split(' ')
+ hw_dict = {"height": int(hw_str[0]), "width": int(hw_str[1])}
+ return hw_dict
+
+ def get_img_key(self, idx):
+ line_no = self.get_line_no(idx)
+ # based on the overhead of reading each row.
+ if self.hw_tsv:
+ return self.hw_tsv[line_no][0]
+ elif self.label_tsv:
+ return self.label_tsv[line_no][0]
+ else:
+ return self.img_tsv[line_no][0]
+
+ def __len__(self):
+ if self.line_list is None:
+ return self.img_tsv.num_rows()
+ else:
+ return len(self.line_list)
+
+ def __getitem__(self, idx):
+
+ img = self.get_image(idx)
+ img_key = self.get_img_key(idx)
+ annotations = self.get_annotations(idx)
+
+ annotations = annotations[0]
+ center = annotations['center']
+ scale = annotations['scale']
+ has_2d_joints = annotations['has_2d_joints']
+ has_3d_joints = annotations['has_3d_joints']
+ joints_2d = np.asarray(annotations['2d_joints'])
+ joints_3d = np.asarray(annotations['3d_joints'])
+
+ if joints_2d.ndim==3:
+ joints_2d = joints_2d[0]
+ if joints_3d.ndim==3:
+ joints_3d = joints_3d[0]
+
+ # Get SMPL parameters, if available
+ has_smpl = np.asarray(annotations['has_smpl'])
+ pose = np.asarray(annotations['pose'])
+ betas = np.asarray(annotations['betas'])
+
+ # Get augmentation parameters
+ flip,pn,rot,sc = self.augm_params()
+
+ # Process image
+ img = self.rgb_processing(img, center, sc*scale, rot, flip, pn)
+ img = torch.from_numpy(img).float()
+ # Store image before normalization to use it in visualization
+ transfromed_img = self.normalize_img(img)
+
+ # normalize 3d pose by aligning the wrist as the root (at origin)
+ root_coord = joints_3d[self.root_index,:-1]
+ joints_3d[:,:-1] = joints_3d[:,:-1] - root_coord[None,:]
+ # 3d pose augmentation (random flip + rotation, consistent to image and SMPL)
+ joints_3d_transformed = self.j3d_processing(joints_3d.copy(), rot, flip)
+ # 2d pose augmentation
+ joints_2d_transformed = self.j2d_processing(joints_2d.copy(), center, sc*scale, rot, flip)
+
+ ###################################
+ # Masking percantage
+ # We observe that 0% or 5% works better for 3D hand mesh
+ # We think this is probably becasue 3D vertices are quite sparse in the down-sampled hand mesh
+ mvm_percent = 0.0 # or 0.05
+ ###################################
+
+ mjm_mask = np.ones((21,1))
+ if self.is_train:
+ num_joints = 21
+ pb = np.random.random_sample()
+ masked_num = int(pb * mvm_percent * num_joints) # at most x% of the joints could be masked
+ indices = np.random.choice(np.arange(num_joints),replace=False,size=masked_num)
+ mjm_mask[indices,:] = 0.0
+ mjm_mask = torch.from_numpy(mjm_mask).float()
+
+ mvm_mask = np.ones((195,1))
+ if self.is_train:
+ num_vertices = 195
+ pb = np.random.random_sample()
+ masked_num = int(pb * mvm_percent * num_vertices) # at most x% of the vertices could be masked
+ indices = np.random.choice(np.arange(num_vertices),replace=False,size=masked_num)
+ mvm_mask[indices,:] = 0.0
+ mvm_mask = torch.from_numpy(mvm_mask).float()
+
+ meta_data = {}
+ meta_data['ori_img'] = img
+ meta_data['pose'] = torch.from_numpy(self.pose_processing(pose, rot, flip)).float()
+ meta_data['betas'] = torch.from_numpy(betas).float()
+ meta_data['joints_3d'] = torch.from_numpy(joints_3d_transformed).float()
+ meta_data['has_3d_joints'] = has_3d_joints
+ meta_data['has_smpl'] = has_smpl
+ meta_data['mjm_mask'] = mjm_mask
+ meta_data['mvm_mask'] = mvm_mask
+
+ # Get 2D keypoints and apply augmentation transforms
+ meta_data['has_2d_joints'] = has_2d_joints
+ meta_data['joints_2d'] = torch.from_numpy(joints_2d_transformed).float()
+
+ meta_data['scale'] = float(sc * scale)
+ meta_data['center'] = np.asarray(center).astype(np.float32)
+
+ return img_key, transfromed_img, meta_data
+
+
+class HandMeshTSVYamlDataset(HandMeshTSVDataset):
+ """ TSVDataset taking a Yaml file for easy function call
+ """
+ def __init__(self, args, yaml_file, is_train=True, cv2_output=False, scale_factor=1):
+ self.cfg = load_from_yaml_file(yaml_file)
+ self.is_composite = self.cfg.get('composite', False)
+ self.root = op.dirname(yaml_file)
+
+ if self.is_composite==False:
+ img_file = find_file_path_in_yaml(self.cfg['img'], self.root)
+ label_file = find_file_path_in_yaml(self.cfg.get('label', None),
+ self.root)
+ hw_file = find_file_path_in_yaml(self.cfg.get('hw', None), self.root)
+ linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None),
+ self.root)
+ else:
+ img_file = self.cfg['img']
+ hw_file = self.cfg['hw']
+ label_file = self.cfg.get('label', None)
+ linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None),
+ self.root)
+
+ super(HandMeshTSVYamlDataset, self).__init__(
+ args, img_file, label_file, hw_file, linelist_file, is_train, cv2_output=cv2_output, scale_factor=scale_factor)
diff --git a/src/custom_mesh_graphormer/datasets/human_mesh_tsv.py b/src/custom_mesh_graphormer/datasets/human_mesh_tsv.py
new file mode 100644
index 0000000000000000000000000000000000000000..b60e083079f82899c3ecaffc3b8c785ec53e42ae
--- /dev/null
+++ b/src/custom_mesh_graphormer/datasets/human_mesh_tsv.py
@@ -0,0 +1,337 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+"""
+
+import cv2
+import math
+import json
+from PIL import Image
+import os.path as op
+import numpy as np
+import code
+
+from custom_mesh_graphormer.utils.tsv_file import TSVFile, CompositeTSVFile
+from custom_mesh_graphormer.utils.tsv_file_ops import load_linelist_file, load_from_yaml_file, find_file_path_in_yaml
+from custom_mesh_graphormer.utils.image_ops import img_from_base64, crop, flip_img, flip_pose, flip_kp, transform, rot_aa
+import torch
+import torchvision.transforms as transforms
+
+
+class MeshTSVDataset(object):
+ def __init__(self, img_file, label_file=None, hw_file=None,
+ linelist_file=None, is_train=True, cv2_output=False, scale_factor=1):
+
+ self.img_file = img_file
+ self.label_file = label_file
+ self.hw_file = hw_file
+ self.linelist_file = linelist_file
+ self.img_tsv = self.get_tsv_file(img_file)
+ self.label_tsv = None if label_file is None else self.get_tsv_file(label_file)
+ self.hw_tsv = None if hw_file is None else self.get_tsv_file(hw_file)
+
+ if self.is_composite:
+ assert op.isfile(self.linelist_file)
+ self.line_list = [i for i in range(self.hw_tsv.num_rows())]
+ else:
+ self.line_list = load_linelist_file(linelist_file)
+
+ self.cv2_output = cv2_output
+ self.normalize_img = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+ self.is_train = is_train
+ self.scale_factor = 0.25 # rescale bounding boxes by a factor of [1-options.scale_factor,1+options.scale_factor]
+ self.noise_factor = 0.4
+ self.rot_factor = 30 # Random rotation in the range [-rot_factor, rot_factor]
+ self.img_res = 224
+
+ self.image_keys = self.prepare_image_keys()
+
+ self.joints_definition = ('R_Ankle', 'R_Knee', 'R_Hip', 'L_Hip', 'L_Knee', 'L_Ankle', 'R_Wrist', 'R_Elbow', 'R_Shoulder', 'L_Shoulder',
+ 'L_Elbow','L_Wrist','Neck','Top_of_Head','Pelvis','Thorax','Spine','Jaw','Head','Nose','L_Eye','R_Eye','L_Ear','R_Ear')
+ self.pelvis_index = self.joints_definition.index('Pelvis')
+
+ def get_tsv_file(self, tsv_file):
+ if tsv_file:
+ if self.is_composite:
+ return CompositeTSVFile(tsv_file, self.linelist_file,
+ root=self.root)
+ tsv_path = find_file_path_in_yaml(tsv_file, self.root)
+ return TSVFile(tsv_path)
+
+ def get_valid_tsv(self):
+ # sorted by file size
+ if self.hw_tsv:
+ return self.hw_tsv
+ if self.label_tsv:
+ return self.label_tsv
+
+ def prepare_image_keys(self):
+ tsv = self.get_valid_tsv()
+ return [tsv.get_key(i) for i in range(tsv.num_rows())]
+
+ def prepare_image_key_to_index(self):
+ tsv = self.get_valid_tsv()
+ return {tsv.get_key(i) : i for i in range(tsv.num_rows())}
+
+
+ def augm_params(self):
+ """Get augmentation parameters."""
+ flip = 0 # flipping
+ pn = np.ones(3) # per channel pixel-noise
+ rot = 0 # rotation
+ sc = 1 # scaling
+ if self.is_train:
+ # We flip with probability 1/2
+ if np.random.uniform() <= 0.5:
+ flip = 1
+
+ # Each channel is multiplied with a number
+ # in the area [1-opt.noiseFactor,1+opt.noiseFactor]
+ pn = np.random.uniform(1-self.noise_factor, 1+self.noise_factor, 3)
+
+ # The rotation is a number in the area [-2*rotFactor, 2*rotFactor]
+ rot = min(2*self.rot_factor,
+ max(-2*self.rot_factor, np.random.randn()*self.rot_factor))
+
+ # The scale is multiplied with a number
+ # in the area [1-scaleFactor,1+scaleFactor]
+ sc = min(1+self.scale_factor,
+ max(1-self.scale_factor, np.random.randn()*self.scale_factor+1))
+ # but it is zero with probability 3/5
+ if np.random.uniform() <= 0.6:
+ rot = 0
+
+ return flip, pn, rot, sc
+
+ def rgb_processing(self, rgb_img, center, scale, rot, flip, pn):
+ """Process rgb image and do augmentation."""
+ rgb_img = crop(rgb_img, center, scale,
+ [self.img_res, self.img_res], rot=rot)
+ # flip the image
+ if flip:
+ rgb_img = flip_img(rgb_img)
+ # in the rgb image we add pixel noise in a channel-wise manner
+ rgb_img[:,:,0] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,0]*pn[0]))
+ rgb_img[:,:,1] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,1]*pn[1]))
+ rgb_img[:,:,2] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,2]*pn[2]))
+ # (3,224,224),float,[0,1]
+ rgb_img = np.transpose(rgb_img.astype('float32'),(2,0,1))/255.0
+ return rgb_img
+
+ def j2d_processing(self, kp, center, scale, r, f):
+ """Process gt 2D keypoints and apply all augmentation transforms."""
+ nparts = kp.shape[0]
+ for i in range(nparts):
+ kp[i,0:2] = transform(kp[i,0:2]+1, center, scale,
+ [self.img_res, self.img_res], rot=r)
+ # convert to normalized coordinates
+ kp[:,:-1] = 2.*kp[:,:-1]/self.img_res - 1.
+ # flip the x coordinates
+ if f:
+ kp = flip_kp(kp)
+ kp = kp.astype('float32')
+ return kp
+
+ def j3d_processing(self, S, r, f):
+ """Process gt 3D keypoints and apply all augmentation transforms."""
+ # in-plane rotation
+ rot_mat = np.eye(3)
+ if not r == 0:
+ rot_rad = -r * np.pi / 180
+ sn,cs = np.sin(rot_rad), np.cos(rot_rad)
+ rot_mat[0,:2] = [cs, -sn]
+ rot_mat[1,:2] = [sn, cs]
+ S[:, :-1] = np.einsum('ij,kj->ki', rot_mat, S[:, :-1])
+ # flip the x coordinates
+ if f:
+ S = flip_kp(S)
+ S = S.astype('float32')
+ return S
+
+ def pose_processing(self, pose, r, f):
+ """Process SMPL theta parameters and apply all augmentation transforms."""
+ # rotation or the pose parameters
+ pose = pose.astype('float32')
+ pose[:3] = rot_aa(pose[:3], r)
+ # flip the pose parameters
+ if f:
+ pose = flip_pose(pose)
+ # (72),float
+ pose = pose.astype('float32')
+ return pose
+
+ def get_line_no(self, idx):
+ return idx if self.line_list is None else self.line_list[idx]
+
+ def get_image(self, idx):
+ line_no = self.get_line_no(idx)
+ row = self.img_tsv[line_no]
+ # use -1 to support old format with multiple columns.
+ cv2_im = img_from_base64(row[-1])
+ if self.cv2_output:
+ return cv2_im.astype(np.float32, copy=True)
+ cv2_im = cv2.cvtColor(cv2_im, cv2.COLOR_BGR2RGB)
+
+ return cv2_im
+
+ def get_annotations(self, idx):
+ line_no = self.get_line_no(idx)
+ if self.label_tsv is not None:
+ row = self.label_tsv[line_no]
+ annotations = json.loads(row[1])
+ return annotations
+ else:
+ return []
+
+ def get_target_from_annotations(self, annotations, img_size, idx):
+ # This function will be overwritten by each dataset to
+ # decode the labels to specific formats for each task.
+ return annotations
+
+
+ def get_img_info(self, idx):
+ if self.hw_tsv is not None:
+ line_no = self.get_line_no(idx)
+ row = self.hw_tsv[line_no]
+ try:
+ # json string format with "height" and "width" being the keys
+ return json.loads(row[1])[0]
+ except ValueError:
+ # list of strings representing height and width in order
+ hw_str = row[1].split(' ')
+ hw_dict = {"height": int(hw_str[0]), "width": int(hw_str[1])}
+ return hw_dict
+
+ def get_img_key(self, idx):
+ line_no = self.get_line_no(idx)
+ # based on the overhead of reading each row.
+ if self.hw_tsv:
+ return self.hw_tsv[line_no][0]
+ elif self.label_tsv:
+ return self.label_tsv[line_no][0]
+ else:
+ return self.img_tsv[line_no][0]
+
+ def __len__(self):
+ if self.line_list is None:
+ return self.img_tsv.num_rows()
+ else:
+ return len(self.line_list)
+
+ def __getitem__(self, idx):
+
+ img = self.get_image(idx)
+ img_key = self.get_img_key(idx)
+ annotations = self.get_annotations(idx)
+
+ annotations = annotations[0]
+ center = annotations['center']
+ scale = annotations['scale']
+ has_2d_joints = annotations['has_2d_joints']
+ has_3d_joints = annotations['has_3d_joints']
+ joints_2d = np.asarray(annotations['2d_joints'])
+ joints_3d = np.asarray(annotations['3d_joints'])
+
+ if joints_2d.ndim==3:
+ joints_2d = joints_2d[0]
+ if joints_3d.ndim==3:
+ joints_3d = joints_3d[0]
+
+ # Get SMPL parameters, if available
+ has_smpl = np.asarray(annotations['has_smpl'])
+ pose = np.asarray(annotations['pose'])
+ betas = np.asarray(annotations['betas'])
+
+ try:
+ gender = annotations['gender']
+ except KeyError:
+ gender = 'none'
+
+ # Get augmentation parameters
+ flip,pn,rot,sc = self.augm_params()
+
+ # Process image
+ img = self.rgb_processing(img, center, sc*scale, rot, flip, pn)
+ img = torch.from_numpy(img).float()
+ # Store image before normalization to use it in visualization
+ transfromed_img = self.normalize_img(img)
+
+ # normalize 3d pose by aligning the pelvis as the root (at origin)
+ root_pelvis = joints_3d[self.pelvis_index,:-1]
+ joints_3d[:,:-1] = joints_3d[:,:-1] - root_pelvis[None,:]
+ # 3d pose augmentation (random flip + rotation, consistent to image and SMPL)
+ joints_3d_transformed = self.j3d_processing(joints_3d.copy(), rot, flip)
+ # 2d pose augmentation
+ joints_2d_transformed = self.j2d_processing(joints_2d.copy(), center, sc*scale, rot, flip)
+
+ ###################################
+ # Masking percantage
+ # We observe that 30% works better for human body mesh. Further details are reported in the paper.
+ mvm_percent = 0.3
+ ###################################
+
+ mjm_mask = np.ones((14,1))
+ if self.is_train:
+ num_joints = 14
+ pb = np.random.random_sample()
+ masked_num = int(pb * mvm_percent * num_joints) # at most x% of the joints could be masked
+ indices = np.random.choice(np.arange(num_joints),replace=False,size=masked_num)
+ mjm_mask[indices,:] = 0.0
+ mjm_mask = torch.from_numpy(mjm_mask).float()
+
+ mvm_mask = np.ones((431,1))
+ if self.is_train:
+ num_vertices = 431
+ pb = np.random.random_sample()
+ masked_num = int(pb * mvm_percent * num_vertices) # at most x% of the vertices could be masked
+ indices = np.random.choice(np.arange(num_vertices),replace=False,size=masked_num)
+ mvm_mask[indices,:] = 0.0
+ mvm_mask = torch.from_numpy(mvm_mask).float()
+
+ meta_data = {}
+ meta_data['ori_img'] = img
+ meta_data['pose'] = torch.from_numpy(self.pose_processing(pose, rot, flip)).float()
+ meta_data['betas'] = torch.from_numpy(betas).float()
+ meta_data['joints_3d'] = torch.from_numpy(joints_3d_transformed).float()
+ meta_data['has_3d_joints'] = has_3d_joints
+ meta_data['has_smpl'] = has_smpl
+
+ meta_data['mjm_mask'] = mjm_mask
+ meta_data['mvm_mask'] = mvm_mask
+
+ # Get 2D keypoints and apply augmentation transforms
+ meta_data['has_2d_joints'] = has_2d_joints
+ meta_data['joints_2d'] = torch.from_numpy(joints_2d_transformed).float()
+ meta_data['scale'] = float(sc * scale)
+ meta_data['center'] = np.asarray(center).astype(np.float32)
+ meta_data['gender'] = gender
+ return img_key, transfromed_img, meta_data
+
+
+
+class MeshTSVYamlDataset(MeshTSVDataset):
+ """ TSVDataset taking a Yaml file for easy function call
+ """
+ def __init__(self, yaml_file, is_train=True, cv2_output=False, scale_factor=1):
+ self.cfg = load_from_yaml_file(yaml_file)
+ self.is_composite = self.cfg.get('composite', False)
+ self.root = op.dirname(yaml_file)
+
+ if self.is_composite==False:
+ img_file = find_file_path_in_yaml(self.cfg['img'], self.root)
+ label_file = find_file_path_in_yaml(self.cfg.get('label', None),
+ self.root)
+ hw_file = find_file_path_in_yaml(self.cfg.get('hw', None), self.root)
+ linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None),
+ self.root)
+ else:
+ img_file = self.cfg['img']
+ hw_file = self.cfg['hw']
+ label_file = self.cfg.get('label', None)
+ linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None),
+ self.root)
+
+ super(MeshTSVYamlDataset, self).__init__(
+ img_file, label_file, hw_file, linelist_file, is_train, cv2_output=cv2_output, scale_factor=scale_factor)
diff --git a/src/custom_mesh_graphormer/modeling/__init__.py b/src/custom_mesh_graphormer/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/custom_mesh_graphormer/modeling/_gcnn.py b/src/custom_mesh_graphormer/modeling/_gcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2baaa97add7357d16d0cb5360e37be35eccf2dd
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/_gcnn.py
@@ -0,0 +1,184 @@
+from __future__ import division
+import torch
+import torch.nn.functional as F
+import numpy as np
+import scipy.sparse
+import math
+from pathlib import Path
+data_path = Path(__file__).parent / "data"
+
+from comfy.model_management import get_torch_device
+from wrapper_for_mps import sparse_to_dense
+device = get_torch_device()
+
+class SparseMM(torch.autograd.Function):
+ """Redefine sparse @ dense matrix multiplication to enable backpropagation.
+ The builtin matrix multiplication operation does not support backpropagation in some cases.
+ """
+ @staticmethod
+ def forward(ctx, sparse, dense):
+ ctx.req_grad = dense.requires_grad
+ ctx.save_for_backward(sparse)
+ return torch.matmul(sparse, dense)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_input = None
+ sparse, = ctx.saved_tensors
+ if ctx.req_grad:
+ grad_input = torch.matmul(sparse.t(), grad_output)
+ return None, grad_input
+
+def spmm(sparse, dense):
+ sparse = sparse.to(device)
+ dense = dense.to(device)
+ return SparseMM.apply(sparse, dense)
+
+
+def gelu(x):
+ """Implementation of the gelu activation function.
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+ Also see https://arxiv.org/abs/1606.08415
+ """
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
+
+class BertLayerNorm(torch.nn.Module):
+ def __init__(self, hidden_size, eps=1e-12):
+ """Construct a layernorm module in the TF style (epsilon inside the square root).
+ """
+ super(BertLayerNorm, self).__init__()
+ self.weight = torch.nn.Parameter(torch.ones(hidden_size))
+ self.bias = torch.nn.Parameter(torch.zeros(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, x):
+ u = x.mean(-1, keepdim=True)
+ s = (x - u).pow(2).mean(-1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.variance_epsilon)
+ return self.weight * x + self.bias
+
+
+class GraphResBlock(torch.nn.Module):
+ """
+ Graph Residual Block similar to the Bottleneck Residual Block in ResNet
+ """
+ def __init__(self, in_channels, out_channels, mesh_type='body'):
+ super(GraphResBlock, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.lin1 = GraphLinear(in_channels, out_channels // 2)
+ self.conv = GraphConvolution(out_channels // 2, out_channels // 2, mesh_type)
+ self.lin2 = GraphLinear(out_channels // 2, out_channels)
+ self.skip_conv = GraphLinear(in_channels, out_channels)
+ # print('Use BertLayerNorm in GraphResBlock')
+ self.pre_norm = BertLayerNorm(in_channels)
+ self.norm1 = BertLayerNorm(out_channels // 2)
+ self.norm2 = BertLayerNorm(out_channels // 2)
+
+ def forward(self, x):
+ trans_y = F.relu(self.pre_norm(x)).transpose(1,2)
+ y = self.lin1(trans_y).transpose(1,2)
+
+ y = F.relu(self.norm1(y))
+ y = self.conv(y)
+
+ trans_y = F.relu(self.norm2(y)).transpose(1,2)
+ y = self.lin2(trans_y).transpose(1,2)
+
+ z = x+y
+
+ return z
+
+# class GraphResBlock(torch.nn.Module):
+# """
+# Graph Residual Block similar to the Bottleneck Residual Block in ResNet
+# """
+# def __init__(self, in_channels, out_channels, mesh_type='body'):
+# super(GraphResBlock, self).__init__()
+# self.in_channels = in_channels
+# self.out_channels = out_channels
+# self.conv = GraphConvolution(self.in_channels, self.out_channels, mesh_type)
+# print('Use BertLayerNorm and GeLU in GraphResBlock')
+# self.norm = BertLayerNorm(self.out_channels)
+# def forward(self, x):
+# y = self.conv(x)
+# y = self.norm(y)
+# y = gelu(y)
+# z = x+y
+# return z
+
+class GraphLinear(torch.nn.Module):
+ """
+ Generalization of 1x1 convolutions on Graphs
+ """
+ def __init__(self, in_channels, out_channels):
+ super(GraphLinear, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.W = torch.nn.Parameter(torch.FloatTensor(out_channels, in_channels))
+ self.b = torch.nn.Parameter(torch.FloatTensor(out_channels))
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ w_stdv = 1 / (self.in_channels * self.out_channels)
+ self.W.data.uniform_(-w_stdv, w_stdv)
+ self.b.data.uniform_(-w_stdv, w_stdv)
+
+ def forward(self, x):
+ return torch.matmul(self.W[None, :], x) + self.b[None, :, None]
+
+class GraphConvolution(torch.nn.Module):
+ """Simple GCN layer, similar to https://arxiv.org/abs/1609.02907."""
+ def __init__(self, in_features, out_features, mesh='body', bias=True):
+ super(GraphConvolution, self).__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+
+ if mesh=='body':
+ adj_indices = torch.load(data_path / 'smpl_431_adjmat_indices.pt')
+ adj_mat_value = torch.load(data_path / 'smpl_431_adjmat_values.pt')
+ adj_mat_size = torch.load(data_path / 'smpl_431_adjmat_size.pt')
+ elif mesh=='hand':
+ adj_indices = torch.load(data_path / 'mano_195_adjmat_indices.pt')
+ adj_mat_value = torch.load(data_path / 'mano_195_adjmat_values.pt')
+ adj_mat_size = torch.load(data_path / 'mano_195_adjmat_size.pt')
+
+ self.adjmat = sparse_to_dense(torch.sparse_coo_tensor(adj_indices, adj_mat_value, size=adj_mat_size)).to(device)
+
+ self.weight = torch.nn.Parameter(torch.FloatTensor(in_features, out_features))
+ if bias:
+ self.bias = torch.nn.Parameter(torch.FloatTensor(out_features))
+ else:
+ self.register_parameter('bias', None)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ # stdv = 1. / math.sqrt(self.weight.size(1))
+ stdv = 6. / math.sqrt(self.weight.size(0) + self.weight.size(1))
+ self.weight.data.uniform_(-stdv, stdv)
+ if self.bias is not None:
+ self.bias.data.uniform_(-stdv, stdv)
+
+ def forward(self, x):
+ if x.ndimension() == 2:
+ support = torch.matmul(x, self.weight)
+ output = torch.matmul(self.adjmat, support)
+ if self.bias is not None:
+ output = output + self.bias
+ return output
+ else:
+ output = []
+ for i in range(x.shape[0]):
+ support = torch.matmul(x[i], self.weight)
+ # output.append(torch.matmul(self.adjmat, support))
+ output.append(spmm(self.adjmat, support))
+ output = torch.stack(output, dim=0)
+ if self.bias is not None:
+ output = output + self.bias
+ return output
+
+ def __repr__(self):
+ return self.__class__.__name__ + ' (' \
+ + str(self.in_features) + ' -> ' \
+ + str(self.out_features) + ')'
\ No newline at end of file
diff --git a/src/custom_mesh_graphormer/modeling/_mano.py b/src/custom_mesh_graphormer/modeling/_mano.py
new file mode 100644
index 0000000000000000000000000000000000000000..8374b6300ce671ebe1ba6070c6002cdf07ad396c
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/_mano.py
@@ -0,0 +1,184 @@
+"""
+This file contains the MANO defination and mesh sampling operations for MANO mesh
+
+Adapted from opensource projects
+MANOPTH (https://github.com/hassony2/manopth)
+Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE)
+GraphCMR (https://github.com/nkolot/GraphCMR/)
+"""
+
+from __future__ import division
+import numpy as np
+import torch
+import torch.nn as nn
+import os.path as osp
+import json
+import code
+from custom_manopth.manolayer import ManoLayer
+import scipy.sparse
+import custom_mesh_graphormer.modeling.data.config as cfg
+from pathlib import Path
+
+from comfy.model_management import get_torch_device
+from wrapper_for_mps import sparse_to_dense
+device = get_torch_device()
+
+class MANO(nn.Module):
+ def __init__(self):
+ super(MANO, self).__init__()
+
+ self.mano_dir = str(Path(__file__).parent / "data")
+ self.layer = self.get_layer()
+ self.vertex_num = 778
+ self.face = self.layer.th_faces.numpy()
+ self.joint_regressor = self.layer.th_J_regressor.numpy()
+
+ self.joint_num = 21
+ self.joints_name = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1', 'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4')
+ self.skeleton = ( (0,1), (0,5), (0,9), (0,13), (0,17), (1,2), (2,3), (3,4), (5,6), (6,7), (7,8), (9,10), (10,11), (11,12), (13,14), (14,15), (15,16), (17,18), (18,19), (19,20) )
+ self.root_joint_idx = self.joints_name.index('Wrist')
+
+ # add fingertips to joint_regressor
+ self.fingertip_vertex_idx = [745, 317, 444, 556, 673] # mesh vertex idx (right hand)
+ thumbtip_onehot = np.array([1 if i == 745 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
+ indextip_onehot = np.array([1 if i == 317 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
+ middletip_onehot = np.array([1 if i == 445 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
+ ringtip_onehot = np.array([1 if i == 556 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
+ pinkytip_onehot = np.array([1 if i == 673 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
+ self.joint_regressor = np.concatenate((self.joint_regressor, thumbtip_onehot, indextip_onehot, middletip_onehot, ringtip_onehot, pinkytip_onehot))
+ self.joint_regressor = self.joint_regressor[[0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20],:]
+ joint_regressor_torch = torch.from_numpy(self.joint_regressor).float()
+ self.register_buffer('joint_regressor_torch', joint_regressor_torch)
+
+ def get_layer(self):
+ return ManoLayer(mano_root=osp.join(self.mano_dir), flat_hand_mean=False, use_pca=False) # load right hand MANO model
+
+ def get_3d_joints(self, vertices):
+ """
+ This method is used to get the joint locations from the SMPL mesh
+ Input:
+ vertices: size = (B, 778, 3)
+ Output:
+ 3D joints: size = (B, 21, 3)
+ """
+ joints = torch.einsum('bik,ji->bjk', [vertices, self.joint_regressor_torch])
+ return joints
+
+
+class SparseMM(torch.autograd.Function):
+ """Redefine sparse @ dense matrix multiplication to enable backpropagation.
+ The builtin matrix multiplication operation does not support backpropagation in some cases.
+ """
+ @staticmethod
+ def forward(ctx, sparse, dense):
+ ctx.req_grad = dense.requires_grad
+ ctx.save_for_backward(sparse)
+ return torch.matmul(sparse, dense)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_input = None
+ sparse, = ctx.saved_tensors
+ if ctx.req_grad:
+ grad_input = torch.matmul(sparse.t(), grad_output)
+ return None, grad_input
+
+def spmm(sparse, dense):
+ sparse = sparse.to(device)
+ dense = dense.to(device)
+ return SparseMM.apply(sparse, dense)
+
+
+def scipy_to_pytorch(A, U, D):
+ """Convert scipy sparse matrices to pytorch sparse matrix."""
+ ptU = []
+ ptD = []
+
+ for i in range(len(U)):
+ u = scipy.sparse.coo_matrix(U[i])
+ i = torch.LongTensor(np.array([u.row, u.col]))
+ v = torch.FloatTensor(u.data)
+ ptU.append(sparse_to_dense(torch.sparse_coo_tensor(i, v, u.shape)))
+
+ for i in range(len(D)):
+ d = scipy.sparse.coo_matrix(D[i])
+ i = torch.LongTensor(np.array([d.row, d.col]))
+ v = torch.FloatTensor(d.data)
+ ptD.append(sparse_to_dense(torch.sparse_coo_tensor(i, v, d.shape)))
+
+ return ptU, ptD
+
+
+def adjmat_sparse(adjmat, nsize=1):
+ """Create row-normalized sparse graph adjacency matrix."""
+ adjmat = scipy.sparse.csr_matrix(adjmat)
+ if nsize > 1:
+ orig_adjmat = adjmat.copy()
+ for _ in range(1, nsize):
+ adjmat = adjmat * orig_adjmat
+ adjmat.data = np.ones_like(adjmat.data)
+ for i in range(adjmat.shape[0]):
+ adjmat[i,i] = 1
+ num_neighbors = np.array(1 / adjmat.sum(axis=-1))
+ adjmat = adjmat.multiply(num_neighbors)
+ adjmat = scipy.sparse.coo_matrix(adjmat)
+ row = adjmat.row
+ col = adjmat.col
+ data = adjmat.data
+ i = torch.LongTensor(np.array([row, col]))
+ v = torch.from_numpy(data).float()
+ adjmat = sparse_to_dense(torch.sparse_coo_tensor(i, v, adjmat.shape))
+ return adjmat
+
+def get_graph_params(filename, nsize=1):
+ """Load and process graph adjacency matrix and upsampling/downsampling matrices."""
+ data = np.load(filename, encoding='latin1', allow_pickle=True)
+ A = data['A']
+ U = data['U']
+ D = data['D']
+ U, D = scipy_to_pytorch(A, U, D)
+ A = [adjmat_sparse(a, nsize=nsize) for a in A]
+ return A, U, D
+
+
+class Mesh(object):
+ """Mesh object that is used for handling certain graph operations."""
+ def __init__(self, filename=cfg.MANO_sampling_matrix,
+ num_downsampling=1, nsize=1, device=torch.device('cuda')):
+ self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize)
+ # self._A = [a.to(device) for a in self._A]
+ self._U = [u.to(device) for u in self._U]
+ self._D = [d.to(device) for d in self._D]
+ self.num_downsampling = num_downsampling
+
+ def downsample(self, x, n1=0, n2=None):
+ """Downsample mesh."""
+ if n2 is None:
+ n2 = self.num_downsampling
+ if x.ndimension() < 3:
+ for i in range(n1, n2):
+ x = spmm(self._D[i], x)
+ elif x.ndimension() == 3:
+ out = []
+ for i in range(x.shape[0]):
+ y = x[i]
+ for j in range(n1, n2):
+ y = spmm(self._D[j], y)
+ out.append(y)
+ x = torch.stack(out, dim=0)
+ return x
+
+ def upsample(self, x, n1=1, n2=0):
+ """Upsample mesh."""
+ if x.ndimension() < 3:
+ for i in reversed(range(n2, n1)):
+ x = spmm(self._U[i], x)
+ elif x.ndimension() == 3:
+ out = []
+ for i in range(x.shape[0]):
+ y = x[i]
+ for j in reversed(range(n2, n1)):
+ y = spmm(self._U[j], y)
+ out.append(y)
+ x = torch.stack(out, dim=0)
+ return x
diff --git a/src/custom_mesh_graphormer/modeling/_smpl.py b/src/custom_mesh_graphormer/modeling/_smpl.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5b46413005e564b0191ea8e8bb9b0641ea97766
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/_smpl.py
@@ -0,0 +1,283 @@
+"""
+This file contains the definition of the SMPL model
+
+It is adapted from opensource project GraphCMR (https://github.com/nkolot/GraphCMR/)
+"""
+from __future__ import division
+
+import torch
+import torch.nn as nn
+import numpy as np
+import scipy.sparse
+try:
+ import cPickle as pickle
+except ImportError:
+ import pickle
+
+from custom_mesh_graphormer.utils.geometric_layers import rodrigues
+import custom_mesh_graphormer.modeling.data.config as cfg
+
+from comfy.model_management import get_torch_device
+from wrapper_for_mps import sparse_to_dense
+device = get_torch_device()
+
+class SMPL(nn.Module):
+
+ def __init__(self, gender='neutral'):
+ super(SMPL, self).__init__()
+
+ if gender=='m':
+ model_file=cfg.SMPL_Male
+ elif gender=='f':
+ model_file=cfg.SMPL_Female
+ else:
+ model_file=cfg.SMPL_FILE
+
+ smpl_model = pickle.load(open(model_file, 'rb'), encoding='latin1')
+ J_regressor = smpl_model['J_regressor'].tocoo()
+ row = J_regressor.row
+ col = J_regressor.col
+ data = J_regressor.data
+ i = torch.LongTensor([row, col])
+ v = torch.FloatTensor(data)
+ J_regressor_shape = [24, 6890]
+ self.register_buffer('J_regressor', torch.sparse_coo_tensor(i, v, J_regressor_shape).to_dense())
+ self.register_buffer('weights', torch.FloatTensor(smpl_model['weights']))
+ self.register_buffer('posedirs', torch.FloatTensor(smpl_model['posedirs']))
+ self.register_buffer('v_template', torch.FloatTensor(smpl_model['v_template']))
+ self.register_buffer('shapedirs', torch.FloatTensor(np.array(smpl_model['shapedirs'])))
+ self.register_buffer('faces', torch.from_numpy(smpl_model['f'].astype(np.int64)))
+ self.register_buffer('kintree_table', torch.from_numpy(smpl_model['kintree_table'].astype(np.int64)))
+ id_to_col = {self.kintree_table[1, i].item(): i for i in range(self.kintree_table.shape[1])}
+ self.register_buffer('parent', torch.LongTensor([id_to_col[self.kintree_table[0, it].item()] for it in range(1, self.kintree_table.shape[1])]))
+
+ self.pose_shape = [24, 3]
+ self.beta_shape = [10]
+ self.translation_shape = [3]
+
+ self.pose = torch.zeros(self.pose_shape)
+ self.beta = torch.zeros(self.beta_shape)
+ self.translation = torch.zeros(self.translation_shape)
+
+ self.verts = None
+ self.J = None
+ self.R = None
+
+ J_regressor_extra = torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_TRAIN_EXTRA)).float()
+ self.register_buffer('J_regressor_extra', J_regressor_extra)
+ self.joints_idx = cfg.JOINTS_IDX
+
+ J_regressor_h36m_correct = torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_H36M_correct)).float()
+ self.register_buffer('J_regressor_h36m_correct', J_regressor_h36m_correct)
+
+
+ def forward(self, pose, beta):
+ device = pose.device
+ batch_size = pose.shape[0]
+ v_template = self.v_template[None, :]
+ shapedirs = self.shapedirs.view(-1,10)[None, :].expand(batch_size, -1, -1)
+ beta = beta[:, :, None]
+ v_shaped = torch.matmul(shapedirs, beta).view(-1, 6890, 3) + v_template
+ # batched sparse matmul not supported in pytorch
+ J = []
+ for i in range(batch_size):
+ J.append(torch.matmul(self.J_regressor, v_shaped[i]))
+ J = torch.stack(J, dim=0)
+ # input it rotmat: (bs,24,3,3)
+ if pose.ndimension() == 4:
+ R = pose
+ # input it rotmat: (bs,72)
+ elif pose.ndimension() == 2:
+ pose_cube = pose.view(-1, 3) # (batch_size * 24, 1, 3)
+ R = rodrigues(pose_cube).view(batch_size, 24, 3, 3)
+ R = R.view(batch_size, 24, 3, 3)
+ I_cube = torch.eye(3)[None, None, :].to(device)
+ # I_cube = torch.eye(3)[None, None, :].expand(theta.shape[0], R.shape[1]-1, -1, -1)
+ lrotmin = (R[:,1:,:] - I_cube).view(batch_size, -1)
+ posedirs = self.posedirs.view(-1,207)[None, :].expand(batch_size, -1, -1)
+ v_posed = v_shaped + torch.matmul(posedirs, lrotmin[:, :, None]).view(-1, 6890, 3)
+ J_ = J.clone()
+ J_[:, 1:, :] = J[:, 1:, :] - J[:, self.parent, :]
+ G_ = torch.cat([R, J_[:, :, :, None]], dim=-1)
+ pad_row = torch.FloatTensor([0,0,0,1]).to(device).view(1,1,1,4).expand(batch_size, 24, -1, -1)
+ G_ = torch.cat([G_, pad_row], dim=2)
+ G = [G_[:, 0].clone()]
+ for i in range(1, 24):
+ G.append(torch.matmul(G[self.parent[i-1]], G_[:, i, :, :]))
+ G = torch.stack(G, dim=1)
+
+ rest = torch.cat([J, torch.zeros(batch_size, 24, 1).to(device)], dim=2).view(batch_size, 24, 4, 1)
+ zeros = torch.zeros(batch_size, 24, 4, 3).to(device)
+ rest = torch.cat([zeros, rest], dim=-1)
+ rest = torch.matmul(G, rest)
+ G = G - rest
+ T = torch.matmul(self.weights, G.permute(1,0,2,3).contiguous().view(24,-1)).view(6890, batch_size, 4, 4).transpose(0,1)
+ rest_shape_h = torch.cat([v_posed, torch.ones_like(v_posed)[:, :, [0]]], dim=-1)
+ v = torch.matmul(T, rest_shape_h[:, :, :, None])[:, :, :3, 0]
+ return v
+
+ def get_joints(self, vertices):
+ """
+ This method is used to get the joint locations from the SMPL mesh
+ Input:
+ vertices: size = (B, 6890, 3)
+ Output:
+ 3D joints: size = (B, 38, 3)
+ """
+ joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor])
+ joints_extra = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_extra])
+ joints = torch.cat((joints, joints_extra), dim=1)
+ joints = joints[:, cfg.JOINTS_IDX]
+ return joints
+
+ def get_h36m_joints(self, vertices):
+ """
+ This method is used to get the joint locations from the SMPL mesh
+ Input:
+ vertices: size = (B, 6890, 3)
+ Output:
+ 3D joints: size = (B, 24, 3)
+ """
+ joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_h36m_correct])
+ return joints
+
+class SparseMM(torch.autograd.Function):
+ """Redefine sparse @ dense matrix multiplication to enable backpropagation.
+ The builtin matrix multiplication operation does not support backpropagation in some cases.
+ """
+ @staticmethod
+ def forward(ctx, sparse, dense):
+ ctx.req_grad = dense.requires_grad
+ ctx.save_for_backward(sparse)
+ return torch.matmul(sparse, dense)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_input = None
+ sparse, = ctx.saved_tensors
+ if ctx.req_grad:
+ grad_input = torch.matmul(sparse.t(), grad_output)
+ return None, grad_input
+
+def spmm(sparse, dense):
+ sparse = sparse.to(device)
+ dense = dense.to(device)
+ return SparseMM.apply(sparse, dense)
+
+
+def scipy_to_pytorch(A, U, D):
+ """Convert scipy sparse matrices to pytorch sparse matrix."""
+ ptU = []
+ ptD = []
+
+ for i in range(len(U)):
+ u = scipy.sparse.coo_matrix(U[i])
+ i = torch.LongTensor(np.array([u.row, u.col]))
+ v = torch.FloatTensor(u.data)
+ ptU.append(sparse_to_dense(torch.sparse_coo_tensor(i, v, u.shape)))
+
+ for i in range(len(D)):
+ d = scipy.sparse.coo_matrix(D[i])
+ i = torch.LongTensor(np.array([d.row, d.col]))
+ v = torch.FloatTensor(d.data)
+ ptD.append(sparse_to_dense(torch.sparse_coo_tensor(i, v, d.shape)))
+
+ return ptU, ptD
+
+
+def adjmat_sparse(adjmat, nsize=1):
+ """Create row-normalized sparse graph adjacency matrix."""
+ adjmat = scipy.sparse.csr_matrix(adjmat)
+ if nsize > 1:
+ orig_adjmat = adjmat.copy()
+ for _ in range(1, nsize):
+ adjmat = adjmat * orig_adjmat
+ adjmat.data = np.ones_like(adjmat.data)
+ for i in range(adjmat.shape[0]):
+ adjmat[i,i] = 1
+ num_neighbors = np.array(1 / adjmat.sum(axis=-1))
+ adjmat = adjmat.multiply(num_neighbors)
+ adjmat = scipy.sparse.coo_matrix(adjmat)
+ row = adjmat.row
+ col = adjmat.col
+ data = adjmat.data
+ i = torch.LongTensor(np.array([row, col]))
+ v = torch.from_numpy(data).float()
+ adjmat = sparse_to_dense(torch.sparse_coo_tensor(i, v, adjmat.shape))
+ return adjmat
+
+def get_graph_params(filename, nsize=1):
+ """Load and process graph adjacency matrix and upsampling/downsampling matrices."""
+ data = np.load(filename, encoding='latin1', allow_pickle=True)
+ A = data['A']
+ U = data['U']
+ D = data['D']
+ U, D = scipy_to_pytorch(A, U, D)
+ A = [adjmat_sparse(a, nsize=nsize) for a in A]
+ return A, U, D
+
+
+class Mesh(object):
+ """Mesh object that is used for handling certain graph operations."""
+ def __init__(self, filename=cfg.SMPL_sampling_matrix,
+ num_downsampling=1, nsize=1, device=torch.device('cuda')):
+ self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize)
+ # self._A = [a.to(device) for a in self._A]
+ self._U = [u.to(device) for u in self._U]
+ self._D = [d.to(device) for d in self._D]
+ self.num_downsampling = num_downsampling
+
+ # load template vertices from SMPL and normalize them
+ smpl = SMPL()
+ ref_vertices = smpl.v_template
+ center = 0.5*(ref_vertices.max(dim=0)[0] + ref_vertices.min(dim=0)[0])[None]
+ ref_vertices -= center
+ ref_vertices /= ref_vertices.abs().max().item()
+
+ self._ref_vertices = ref_vertices.to(device)
+ self.faces = smpl.faces.int().to(device)
+
+ # @property
+ # def adjmat(self):
+ # """Return the graph adjacency matrix at the specified subsampling level."""
+ # return self._A[self.num_downsampling].float()
+
+ @property
+ def ref_vertices(self):
+ """Return the template vertices at the specified subsampling level."""
+ ref_vertices = self._ref_vertices
+ for i in range(self.num_downsampling):
+ ref_vertices = torch.spmm(self._D[i], ref_vertices)
+ return ref_vertices
+
+ def downsample(self, x, n1=0, n2=None):
+ """Downsample mesh."""
+ if n2 is None:
+ n2 = self.num_downsampling
+ if x.ndimension() < 3:
+ for i in range(n1, n2):
+ x = spmm(self._D[i], x)
+ elif x.ndimension() == 3:
+ out = []
+ for i in range(x.shape[0]):
+ y = x[i]
+ for j in range(n1, n2):
+ y = spmm(self._D[j], y)
+ out.append(y)
+ x = torch.stack(out, dim=0)
+ return x
+
+ def upsample(self, x, n1=1, n2=0):
+ """Upsample mesh."""
+ if x.ndimension() < 3:
+ for i in reversed(range(n2, n1)):
+ x = spmm(self._U[i], x)
+ elif x.ndimension() == 3:
+ out = []
+ for i in range(x.shape[0]):
+ y = x[i]
+ for j in reversed(range(n2, n1)):
+ y = spmm(self._U[j], y)
+ out.append(y)
+ x = torch.stack(out, dim=0)
+ return x
diff --git a/src/custom_mesh_graphormer/modeling/bert/__init__.py b/src/custom_mesh_graphormer/modeling/bert/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..197c5b92786fcb15d754555c2cd4d3531f98a0e0
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/bert/__init__.py
@@ -0,0 +1,17 @@
+__version__ = "1.0.0"
+
+from .modeling_bert import (BertConfig, BertModel,
+ load_tf_weights_in_bert)
+
+from .modeling_graphormer import Graphormer
+
+from .e2e_body_network import Graphormer_Body_Network
+
+from .e2e_hand_network import Graphormer_Hand_Network
+
+CONFIG_NAME = "config.json"
+
+from .modeling_utils import (WEIGHTS_NAME, TF_WEIGHTS_NAME,
+ PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
+
+from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE)
diff --git a/src/custom_mesh_graphormer/modeling/bert/bert-base-uncased/config.json b/src/custom_mesh_graphormer/modeling/bert/bert-base-uncased/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..79276673252f15cea400800731e0d4e3d3cba64f
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/bert/bert-base-uncased/config.json
@@ -0,0 +1,16 @@
+{
+ "architectures": [
+ "BertForMaskedLM"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "max_position_embeddings": 512,
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "type_vocab_size": 2,
+ "vocab_size": 30522
+}
diff --git a/src/custom_mesh_graphormer/modeling/bert/e2e_body_network.py b/src/custom_mesh_graphormer/modeling/bert/e2e_body_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..f570ec58f66c20802898103555949e74ff56930c
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/bert/e2e_body_network.py
@@ -0,0 +1,103 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+"""
+
+import torch
+import custom_mesh_graphormer.modeling.data.config as cfg
+from comfy.model_management import get_torch_device
+device = get_torch_device()
+
+class Graphormer_Body_Network(torch.nn.Module):
+ '''
+ End-to-end Graphormer network for human pose and mesh reconstruction from a single image.
+ '''
+ def __init__(self, args, config, backbone, trans_encoder, mesh_sampler):
+ super(Graphormer_Body_Network, self).__init__()
+ self.config = config
+ self.config.device = device
+ self.backbone = backbone
+ self.trans_encoder = trans_encoder
+ self.upsampling = torch.nn.Linear(431, 1723)
+ self.upsampling2 = torch.nn.Linear(1723, 6890)
+ self.cam_param_fc = torch.nn.Linear(3, 1)
+ self.cam_param_fc2 = torch.nn.Linear(431, 250)
+ self.cam_param_fc3 = torch.nn.Linear(250, 3)
+ self.grid_feat_dim = torch.nn.Linear(1024, 2051)
+
+
+ def forward(self, images, smpl, mesh_sampler, meta_masks=None, is_train=False):
+ batch_size = images.size(0)
+ # Generate T-pose template mesh
+ template_pose = torch.zeros((1,72))
+ template_pose[:,0] = 3.1416 # Rectify "upside down" reference mesh in global coord
+ template_pose = template_pose.to(device)
+ template_betas = torch.zeros((1,10)).to(device)
+ template_vertices = smpl(template_pose, template_betas)
+
+ # template mesh simplification
+ template_vertices_sub = mesh_sampler.downsample(template_vertices)
+ template_vertices_sub2 = mesh_sampler.downsample(template_vertices_sub, n1=1, n2=2)
+
+ # template mesh-to-joint regression
+ template_3d_joints = smpl.get_h36m_joints(template_vertices)
+ template_pelvis = template_3d_joints[:,cfg.H36M_J17_NAME.index('Pelvis'),:]
+ template_3d_joints = template_3d_joints[:,cfg.H36M_J17_TO_J14,:]
+ num_joints = template_3d_joints.shape[1]
+
+ # normalize
+ template_3d_joints = template_3d_joints - template_pelvis[:, None, :]
+ template_vertices_sub2 = template_vertices_sub2 - template_pelvis[:, None, :]
+
+ # concatinate template joints and template vertices, and then duplicate to batch size
+ ref_vertices = torch.cat([template_3d_joints, template_vertices_sub2],dim=1)
+ ref_vertices = ref_vertices.expand(batch_size, -1, -1)
+
+ # extract grid features and global image features using a CNN backbone
+ image_feat, grid_feat = self.backbone(images)
+ # concatinate image feat and 3d mesh template
+ image_feat = image_feat.view(batch_size, 1, 2048).expand(-1, ref_vertices.shape[-2], -1)
+ # process grid features
+ grid_feat = torch.flatten(grid_feat, start_dim=2)
+ grid_feat = grid_feat.transpose(1,2)
+ grid_feat = self.grid_feat_dim(grid_feat)
+ # concatinate image feat and template mesh to form the joint/vertex queries
+ features = torch.cat([ref_vertices, image_feat], dim=2)
+ # prepare input tokens including joint/vertex queries and grid features
+ features = torch.cat([features, grid_feat],dim=1)
+
+ if is_train==True:
+ # apply mask vertex/joint modeling
+ # meta_masks is a tensor of all the masks, randomly generated in dataloader
+ # we pre-define a [MASK] token, which is a floating-value vector with 0.01s
+ special_token = torch.ones_like(features[:,:-49,:]).to(device)*0.01
+ features[:,:-49,:] = features[:,:-49,:]*meta_masks + special_token*(1-meta_masks)
+
+ # forward pass
+ if self.config.output_attentions==True:
+ features, hidden_states, att = self.trans_encoder(features)
+ else:
+ features = self.trans_encoder(features)
+
+ pred_3d_joints = features[:,:num_joints,:]
+ pred_vertices_sub2 = features[:,num_joints:-49,:]
+
+ # learn camera parameters
+ x = self.cam_param_fc(pred_vertices_sub2)
+ x = x.transpose(1,2)
+ x = self.cam_param_fc2(x)
+ x = self.cam_param_fc3(x)
+ cam_param = x.transpose(1,2)
+ cam_param = cam_param.squeeze()
+
+ temp_transpose = pred_vertices_sub2.transpose(1,2)
+ pred_vertices_sub = self.upsampling(temp_transpose)
+ pred_vertices_full = self.upsampling2(pred_vertices_sub)
+ pred_vertices_sub = pred_vertices_sub.transpose(1,2)
+ pred_vertices_full = pred_vertices_full.transpose(1,2)
+
+ if self.config.output_attentions==True:
+ return cam_param, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices_full, hidden_states, att
+ else:
+ return cam_param, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices_full
\ No newline at end of file
diff --git a/src/custom_mesh_graphormer/modeling/bert/e2e_hand_network.py b/src/custom_mesh_graphormer/modeling/bert/e2e_hand_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..57cba0a830c2ca975048fd1a3de9f8341196ba43
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/bert/e2e_hand_network.py
@@ -0,0 +1,94 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+"""
+
+import torch
+import custom_mesh_graphormer.modeling.data.config as cfg
+from comfy.model_management import get_torch_device
+device = get_torch_device()
+
+class Graphormer_Hand_Network(torch.nn.Module):
+ '''
+ End-to-end Graphormer network for hand pose and mesh reconstruction from a single image.
+ '''
+ def __init__(self, args, config, backbone, trans_encoder):
+ super(Graphormer_Hand_Network, self).__init__()
+ self.config = config
+ self.backbone = backbone
+ self.trans_encoder = trans_encoder
+ self.upsampling = torch.nn.Linear(195, 778)
+ self.cam_param_fc = torch.nn.Linear(3, 1)
+ self.cam_param_fc2 = torch.nn.Linear(195+21, 150)
+ self.cam_param_fc3 = torch.nn.Linear(150, 3)
+ self.grid_feat_dim = torch.nn.Linear(1024, 2051)
+
+ def forward(self, images, mesh_model, mesh_sampler, meta_masks=None, is_train=False):
+ batch_size = images.size(0)
+ # Generate T-pose template mesh
+ template_pose = torch.zeros((1,48))
+ template_pose = template_pose.to(device)
+ template_betas = torch.zeros((1,10)).to(device)
+ template_vertices, template_3d_joints = mesh_model.layer(template_pose, template_betas)
+ template_vertices = template_vertices/1000.0
+ template_3d_joints = template_3d_joints/1000.0
+
+ template_vertices_sub = mesh_sampler.downsample(template_vertices)
+
+ # normalize
+ template_root = template_3d_joints[:,cfg.J_NAME.index('Wrist'),:]
+ template_3d_joints = template_3d_joints - template_root[:, None, :]
+ template_vertices = template_vertices - template_root[:, None, :]
+ template_vertices_sub = template_vertices_sub - template_root[:, None, :]
+ num_joints = template_3d_joints.shape[1]
+
+ # concatinate template joints and template vertices, and then duplicate to batch size
+ ref_vertices = torch.cat([template_3d_joints, template_vertices_sub],dim=1)
+ ref_vertices = ref_vertices.expand(batch_size, -1, -1)
+
+ # extract grid features and global image features using a CNN backbone
+ image_feat, grid_feat = self.backbone(images)
+ # concatinate image feat and mesh template
+ image_feat = image_feat.view(batch_size, 1, 2048).expand(-1, ref_vertices.shape[-2], -1)
+ # process grid features
+ grid_feat = torch.flatten(grid_feat, start_dim=2)
+ grid_feat = grid_feat.transpose(1,2)
+ grid_feat = self.grid_feat_dim(grid_feat)
+ # concatinate image feat and template mesh to form the joint/vertex queries
+ features = torch.cat([ref_vertices, image_feat], dim=2)
+ # prepare input tokens including joint/vertex queries and grid features
+ features = torch.cat([features, grid_feat],dim=1)
+
+ if is_train==True:
+ # apply mask vertex/joint modeling
+ # meta_masks is a tensor of all the masks, randomly generated in dataloader
+ # we pre-define a [MASK] token, which is a floating-value vector with 0.01s
+ special_token = torch.ones_like(features[:,:-49,:]).to(device)*0.01
+ features[:,:-49,:] = features[:,:-49,:]*meta_masks + special_token*(1-meta_masks)
+
+ # forward pass
+ if self.config.output_attentions==True:
+ features, hidden_states, att = self.trans_encoder(features)
+ else:
+ features = self.trans_encoder(features)
+
+ pred_3d_joints = features[:,:num_joints,:]
+ pred_vertices_sub = features[:,num_joints:-49,:]
+
+ # learn camera parameters
+ x = self.cam_param_fc(features[:,:-49,:])
+ x = x.transpose(1,2)
+ x = self.cam_param_fc2(x)
+ x = self.cam_param_fc3(x)
+ cam_param = x.transpose(1,2)
+ cam_param = cam_param.squeeze()
+
+ temp_transpose = pred_vertices_sub.transpose(1,2)
+ pred_vertices = self.upsampling(temp_transpose)
+ pred_vertices = pred_vertices.transpose(1,2)
+
+ if self.config.output_attentions==True:
+ return cam_param, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att
+ else:
+ return cam_param, pred_3d_joints, pred_vertices_sub, pred_vertices
\ No newline at end of file
diff --git a/src/custom_mesh_graphormer/modeling/bert/file_utils.py b/src/custom_mesh_graphormer/modeling/bert/file_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b26c2ba2e0a5371ee599dc5253b160fa04ba510
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/bert/file_utils.py
@@ -0,0 +1 @@
+from transformers.file_utils import *
\ No newline at end of file
diff --git a/src/custom_mesh_graphormer/modeling/bert/modeling_bert.py b/src/custom_mesh_graphormer/modeling/bert/modeling_bert.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b6a1eeeb6d1ba9d4fedff7f48ac55eaa0072196
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/bert/modeling_bert.py
@@ -0,0 +1,5 @@
+from transformers.models.bert import modeling_bert
+
+for symbol in dir(modeling_bert):
+ if not symbol.startswith("_"):
+ globals()[symbol] = getattr(modeling_bert, symbol)
diff --git a/src/custom_mesh_graphormer/modeling/bert/modeling_graphormer.py b/src/custom_mesh_graphormer/modeling/bert/modeling_graphormer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7c6e2625afaee69f4f9350374443d4f21e881a1
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/bert/modeling_graphormer.py
@@ -0,0 +1,328 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+"""
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import logging
+import math
+import os
+import code
+import torch
+from torch import nn
+from .modeling_bert import BertPreTrainedModel, BertEmbeddings, BertPooler, BertIntermediate, BertOutput, BertSelfOutput
+import custom_mesh_graphormer.modeling.data.config as cfg
+from custom_mesh_graphormer.modeling._gcnn import GraphConvolution, GraphResBlock
+from .modeling_utils import prune_linear_layer
+LayerNormClass = torch.nn.LayerNorm
+BertLayerNorm = torch.nn.LayerNorm
+from comfy.model_management import get_torch_device
+device = get_torch_device()
+
+
+class BertSelfAttention(nn.Module):
+ def __init__(self, config):
+ super(BertSelfAttention, self).__init__()
+ if config.hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+ "The hidden size (%d) is not a multiple of the number of attention "
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads))
+ self.output_attentions = config.output_attentions
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(self, hidden_states, attention_mask, head_mask=None,
+ history_state=None):
+ if history_state is not None:
+ x_states = torch.cat([history_state, hidden_states], dim=1)
+ mixed_query_layer = self.query(hidden_states)
+ mixed_key_layer = self.key(x_states)
+ mixed_value_layer = self.value(x_states)
+ else:
+ mixed_query_layer = self.query(hidden_states)
+ mixed_key_layer = self.key(hidden_states)
+ mixed_value_layer = self.value(hidden_states)
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+ key_layer = self.transpose_for_scores(mixed_key_layer)
+ value_layer = self.transpose_for_scores(mixed_value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
+ return outputs
+
+class BertAttention(nn.Module):
+ def __init__(self, config):
+ super(BertAttention, self).__init__()
+ self.self = BertSelfAttention(config)
+ self.output = BertSelfOutput(config)
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
+ for head in heads:
+ mask[head] = 0
+ mask = mask.view(-1).contiguous().eq(1)
+ index = torch.arange(len(mask))[mask].long()
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+ # Update hyper params
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+
+ def forward(self, input_tensor, attention_mask, head_mask=None,
+ history_state=None):
+ self_outputs = self.self(input_tensor, attention_mask, head_mask,
+ history_state)
+ attention_output = self.output(self_outputs[0], input_tensor)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class GraphormerLayer(nn.Module):
+ def __init__(self, config):
+ super(GraphormerLayer, self).__init__()
+ self.attention = BertAttention(config)
+ self.has_graph_conv = config.graph_conv
+ self.mesh_type = config.mesh_type
+
+ if self.has_graph_conv == True:
+ self.graph_conv = GraphResBlock(config.hidden_size, config.hidden_size, mesh_type=self.mesh_type)
+
+ self.intermediate = BertIntermediate(config)
+ self.output = BertOutput(config)
+
+ def MHA_GCN(self, hidden_states, attention_mask, head_mask=None,
+ history_state=None):
+ attention_outputs = self.attention(hidden_states, attention_mask,
+ head_mask, history_state)
+ attention_output = attention_outputs[0]
+
+ if self.has_graph_conv==True:
+ if self.mesh_type == 'body':
+ joints = attention_output[:,0:14,:]
+ vertices = attention_output[:,14:-49,:]
+ img_tokens = attention_output[:,-49:,:]
+
+ elif self.mesh_type == 'hand':
+ joints = attention_output[:,0:21,:]
+ vertices = attention_output[:,21:-49,:]
+ img_tokens = attention_output[:,-49:,:]
+
+ vertices = self.graph_conv(vertices)
+ joints_vertices = torch.cat([joints,vertices,img_tokens],dim=1)
+ else:
+ joints_vertices = attention_output
+
+ intermediate_output = self.intermediate(joints_vertices)
+ layer_output = self.output(intermediate_output, joints_vertices)
+ outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
+ return outputs
+
+ def forward(self, hidden_states, attention_mask, head_mask=None,
+ history_state=None):
+ return self.MHA_GCN(hidden_states, attention_mask, head_mask,history_state)
+
+
+class GraphormerEncoder(nn.Module):
+ def __init__(self, config):
+ super(GraphormerEncoder, self).__init__()
+ self.output_attentions = config.output_attentions
+ self.output_hidden_states = config.output_hidden_states
+ self.layer = nn.ModuleList([GraphormerLayer(config) for _ in range(config.num_hidden_layers)])
+
+ def forward(self, hidden_states, attention_mask, head_mask=None,
+ encoder_history_states=None):
+ all_hidden_states = ()
+ all_attentions = ()
+ for i, layer_module in enumerate(self.layer):
+ if self.output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ history_state = None if encoder_history_states is None else encoder_history_states[i]
+ layer_outputs = layer_module(
+ hidden_states, attention_mask, head_mask[i],
+ history_state)
+ hidden_states = layer_outputs[0]
+
+ if self.output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ # Add last layer
+ if self.output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ outputs = (hidden_states,)
+ if self.output_hidden_states:
+ outputs = outputs + (all_hidden_states,)
+ if self.output_attentions:
+ outputs = outputs + (all_attentions,)
+
+ return outputs # outputs, (hidden states), (attentions)
+
+class EncoderBlock(BertPreTrainedModel):
+ def __init__(self, config):
+ super(EncoderBlock, self).__init__(config)
+ self.config = config
+ self.embeddings = BertEmbeddings(config)
+ self.encoder = GraphormerEncoder(config)
+ self.pooler = BertPooler(config)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+ self.img_dim = config.img_feature_dim
+
+ try:
+ self.use_img_layernorm = config.use_img_layernorm
+ except:
+ self.use_img_layernorm = None
+
+ self.img_embedding = nn.Linear(self.img_dim, self.config.hidden_size, bias=True)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ if self.use_img_layernorm:
+ self.LayerNorm = LayerNormClass(config.hidden_size, eps=config.img_layer_norm_eps)
+
+
+ def _prune_heads(self, heads_to_prune):
+ """ Prunes heads of the model.
+ heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+ See base class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ def forward(self, img_feats, input_ids=None, token_type_ids=None, attention_mask=None,
+ position_ids=None, head_mask=None):
+
+ batch_size = len(img_feats)
+ seq_length = len(img_feats[0])
+ input_ids = torch.zeros([batch_size, seq_length],dtype=torch.long).to(device)
+
+ if position_ids is None:
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+
+ position_embeddings = self.position_embeddings(position_ids)
+
+ if attention_mask is None:
+ attention_mask = torch.ones_like(input_ids)
+
+ if token_type_ids is None:
+ token_type_ids = torch.zeros_like(input_ids)
+
+ if attention_mask.dim() == 2:
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+ elif attention_mask.dim() == 3:
+ extended_attention_mask = attention_mask.unsqueeze(1)
+ else:
+ raise NotImplementedError
+
+ extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+
+ if head_mask is not None:
+ if head_mask.dim() == 1:
+ head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+ head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
+ elif head_mask.dim() == 2:
+ head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
+ head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
+ else:
+ head_mask = [None] * self.config.num_hidden_layers
+
+ # Project input token features to have spcified hidden size
+ img_embedding_output = self.img_embedding(img_feats)
+
+ # We empirically observe that adding an additional learnable position embedding leads to more stable training
+ embeddings = position_embeddings + img_embedding_output
+
+ if self.use_img_layernorm:
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+
+ encoder_outputs = self.encoder(embeddings,
+ extended_attention_mask, head_mask=head_mask)
+ sequence_output = encoder_outputs[0]
+
+ outputs = (sequence_output,)
+ if self.config.output_hidden_states:
+ all_hidden_states = encoder_outputs[1]
+ outputs = outputs + (all_hidden_states,)
+ if self.config.output_attentions:
+ all_attentions = encoder_outputs[-1]
+ outputs = outputs + (all_attentions,)
+
+ return outputs
+
+class Graphormer(BertPreTrainedModel):
+ '''
+ The archtecture of a transformer encoder block we used in Graphormer
+ '''
+ def __init__(self, config):
+ super(Graphormer, self).__init__(config)
+ self.config = config
+ self.bert = EncoderBlock(config)
+ self.cls_head = nn.Linear(config.hidden_size, self.config.output_feature_dim)
+ self.residual = nn.Linear(config.img_feature_dim, self.config.output_feature_dim)
+
+ def forward(self, img_feats, input_ids=None, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
+ next_sentence_label=None, position_ids=None, head_mask=None):
+ '''
+ # self.bert has three outputs
+ # predictions[0]: output tokens
+ # predictions[1]: all_hidden_states, if enable "self.config.output_hidden_states"
+ # predictions[2]: attentions, if enable "self.config.output_attentions"
+ '''
+ predictions = self.bert(img_feats=img_feats, input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
+ attention_mask=attention_mask, head_mask=head_mask)
+
+ # We use "self.cls_head" to perform dimensionality reduction. We don't use it for classification.
+ pred_score = self.cls_head(predictions[0])
+ res_img_feats = self.residual(img_feats)
+ pred_score = pred_score + res_img_feats
+
+ if self.config.output_attentions and self.config.output_hidden_states:
+ return pred_score, predictions[1], predictions[-1]
+ else:
+ return pred_score
+
+
\ No newline at end of file
diff --git a/src/custom_mesh_graphormer/modeling/bert/modeling_utils.py b/src/custom_mesh_graphormer/modeling/bert/modeling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcfab434a6ce1f7d77f34c47e4756724029663c9
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/bert/modeling_utils.py
@@ -0,0 +1 @@
+from transformers.modeling_utils import *
\ No newline at end of file
diff --git a/src/custom_mesh_graphormer/modeling/data/J_regressor_extra.npy b/src/custom_mesh_graphormer/modeling/data/J_regressor_extra.npy
new file mode 100644
index 0000000000000000000000000000000000000000..c15c7c4294d859ee037404876073a969c0da5524
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/data/J_regressor_extra.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:40dfaa71fcc7eed6966a6ed046311b7e8ea0eb9a5172b298e3df6fc4b6ec0eb0
+size 771808
diff --git a/src/custom_mesh_graphormer/modeling/data/J_regressor_h36m_correct.npy b/src/custom_mesh_graphormer/modeling/data/J_regressor_h36m_correct.npy
new file mode 100644
index 0000000000000000000000000000000000000000..dff7bedc5d08289a308299a6c82df39484e4b62b
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/data/J_regressor_h36m_correct.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1835d64133d5f66bd80a814ab1c1dc0900ef01950f568320acf5f9390c1f2c8c
+size 937168
diff --git a/src/custom_mesh_graphormer/modeling/data/MANO_LEFT.pkl b/src/custom_mesh_graphormer/modeling/data/MANO_LEFT.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..94bb3a9fbdbbc001b985f6fe36cb290773cc55b2
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/data/MANO_LEFT.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b61cdb94a33582d626456752515624d7c558b5adcc997d13fb422963b5f791ed
+size 3447713
diff --git a/src/custom_mesh_graphormer/modeling/data/MANO_RIGHT.pkl b/src/custom_mesh_graphormer/modeling/data/MANO_RIGHT.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..bc0e78214e1408bb7c8b9aa058fd0910fa08d248
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/data/MANO_RIGHT.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e3fb9ac790637539011258e211415d3bcae8daa2759c86b9046f4b371f0c423
+size 3447679
diff --git a/src/custom_mesh_graphormer/modeling/data/README.md b/src/custom_mesh_graphormer/modeling/data/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e7cfc083291ab8ec837f5f63f2d04643a33659b8
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/data/README.md
@@ -0,0 +1,30 @@
+
+# Extra data
+Adapted from open source project [GraphCMR](https://github.com/nkolot/GraphCMR/) and [Pose2Mesh](https://github.com/hongsukchoi/Pose2Mesh_RELEASE)
+
+Our code requires additional data to run smoothly.
+
+### J_regressor_extra.npy
+Joints regressor for joints or landmarks that are not included in the standard set of SMPL joints.
+
+### J_regressor_h36m_correct.npy
+Joints regressor reflecting the Human3.6M joints.
+
+### mesh_downsampling.npz
+Extra file with precomputed downsampling for the SMPL body mesh.
+
+### mano_downsampling.npz
+Extra file with precomputed downsampling for the MANO hand mesh.
+
+### basicModel_neutral_lbs_10_207_0_v1.0.0.pkl
+SMPL neutral model. Please visit the official website to download the file [http://smplify.is.tue.mpg.de/](http://smplify.is.tue.mpg.de/)
+
+### basicModel_m_lbs_10_207_0_v1.0.0.pkl
+SMPL male model. Please visit the official website to download the file [https://smpl.is.tue.mpg.de/](https://smpl.is.tue.mpg.de/)
+
+### basicModel_f_lbs_10_207_0_v1.0.0.pkl
+SMPL female model. Please visit the official website to download the file [https://smpl.is.tue.mpg.de/](https://smpl.is.tue.mpg.de/)
+
+### MANO_RIGHT.pkl
+MANO hand model. Please visit the official website to download the file [https://mano.is.tue.mpg.de/](https://mano.is.tue.mpg.de/)
+
diff --git a/src/custom_mesh_graphormer/modeling/data/config.py b/src/custom_mesh_graphormer/modeling/data/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8eb40f7d5130616173489cbf8697d76da504c29
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/data/config.py
@@ -0,0 +1,47 @@
+"""
+This file contains definitions of useful data stuctures and the paths
+for the datasets and data files necessary to run the code.
+
+Adapted from opensource project GraphCMR (https://github.com/nkolot/GraphCMR/) and Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE)
+
+"""
+
+from pathlib import Path
+folder_path = Path(__file__).parent.parent
+JOINT_REGRESSOR_TRAIN_EXTRA = folder_path / 'data/J_regressor_extra.npy'
+JOINT_REGRESSOR_H36M_correct = folder_path / 'data/J_regressor_h36m_correct.npy'
+SMPL_FILE = folder_path / 'data/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl'
+SMPL_Male = folder_path / 'data/basicModel_m_lbs_10_207_0_v1.0.0.pkl'
+SMPL_Female = folder_path / 'data/basicModel_f_lbs_10_207_0_v1.0.0.pkl'
+SMPL_sampling_matrix = folder_path / 'data/mesh_downsampling.npz'
+MANO_FILE = folder_path / 'data/MANO_RIGHT.pkl'
+MANO_sampling_matrix = folder_path / 'data/mano_downsampling.npz'
+
+JOINTS_IDX = [8, 5, 29, 30, 4, 7, 21, 19, 17, 16, 18, 20, 31, 32, 33, 34, 35, 36, 37, 24, 26, 25, 28, 27]
+
+
+"""
+We follow the body joint definition, loss functions, and evaluation metrics from
+open source project GraphCMR (https://github.com/nkolot/GraphCMR/)
+
+Each dataset uses different sets of joints.
+We use a superset of 24 joints such that we include all joints from every dataset.
+If a dataset doesn't provide annotations for a specific joint, we simply ignore it.
+The joints used here are:
+"""
+J24_NAME = ('R_Ankle', 'R_Knee', 'R_Hip', 'L_Hip', 'L_Knee', 'L_Ankle', 'R_Wrist', 'R_Elbow', 'R_Shoulder', 'L_Shoulder',
+'L_Elbow','L_Wrist','Neck','Top_of_Head','Pelvis','Thorax','Spine','Jaw','Head','Nose','L_Eye','R_Eye','L_Ear','R_Ear')
+H36M_J17_NAME = ( 'Pelvis', 'R_Hip', 'R_Knee', 'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'Torso', 'Neck', 'Nose', 'Head',
+ 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Shoulder', 'R_Elbow', 'R_Wrist')
+J24_TO_J14 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18]
+H36M_J17_TO_J14 = [3, 2, 1, 4, 5, 6, 16, 15, 14, 11, 12, 13, 8, 10]
+
+"""
+We follow the hand joint definition and mesh topology from
+open source project Manopth (https://github.com/hassony2/manopth)
+
+The hand joints used here are:
+"""
+J_NAME = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1',
+'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4')
+ROOT_INDEX = 0
\ No newline at end of file
diff --git a/src/custom_mesh_graphormer/modeling/data/mano_195_adjmat_indices.pt b/src/custom_mesh_graphormer/modeling/data/mano_195_adjmat_indices.pt
new file mode 100644
index 0000000000000000000000000000000000000000..13c417de4303fb93c67fc387e5fd0ece981725f0
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/data/mano_195_adjmat_indices.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f99d80a96bbced27df6b6dee4fbdc01ee326e7e2691a79ca596ad03f57db8a6a
+size 21639
diff --git a/src/custom_mesh_graphormer/modeling/data/mano_195_adjmat_size.pt b/src/custom_mesh_graphormer/modeling/data/mano_195_adjmat_size.pt
new file mode 100644
index 0000000000000000000000000000000000000000..6d3c15248ea036b333ebc7c2daad0c81ed311dd9
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/data/mano_195_adjmat_size.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cd137f9f2b1b8827251934784326a5b05f1415333101c13c849ed6b5eba6c3a4
+size 173
diff --git a/src/custom_mesh_graphormer/modeling/data/mano_195_adjmat_values.pt b/src/custom_mesh_graphormer/modeling/data/mano_195_adjmat_values.pt
new file mode 100644
index 0000000000000000000000000000000000000000..304f445e207e41b2ed3663e0d7557a9b028291ae
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/data/mano_195_adjmat_values.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:29945bf3dce6ce679fc6b35ed9007aa531f68e6e2e44a59d5ab69a643862f8f9
+size 5663
diff --git a/src/custom_mesh_graphormer/modeling/data/mano_downsampling.npz b/src/custom_mesh_graphormer/modeling/data/mano_downsampling.npz
new file mode 100644
index 0000000000000000000000000000000000000000..1e2197737db3ffaab05713e49784c85c7b83afe9
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/data/mano_downsampling.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db2b23b0ede7c34039f1d8960e2e839552e0f58d4b34a8612b4065d8a47f9c80
+size 176509
diff --git a/src/custom_mesh_graphormer/modeling/data/mesh_downsampling.npz b/src/custom_mesh_graphormer/modeling/data/mesh_downsampling.npz
new file mode 100644
index 0000000000000000000000000000000000000000..ee14dddb06717c1312b81f3fa74a7ef2ea4357ac
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/data/mesh_downsampling.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5683b656d9acd7f1558db832527beb0cc6b3b45388cf41a7979dab52e2c57477
+size 1720359
diff --git a/src/custom_mesh_graphormer/modeling/data/smpl_431_adjmat_indices.pt b/src/custom_mesh_graphormer/modeling/data/smpl_431_adjmat_indices.pt
new file mode 100644
index 0000000000000000000000000000000000000000..fd4155b9eb7adb6b1755215d0d992a1926b9a412
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/data/smpl_431_adjmat_indices.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6cd71fd5e0ed2c55b909fad982a80fe3a0ddaf6e203721e030c9ea0246891f73
+size 48423
diff --git a/src/custom_mesh_graphormer/modeling/data/smpl_431_adjmat_size.pt b/src/custom_mesh_graphormer/modeling/data/smpl_431_adjmat_size.pt
new file mode 100644
index 0000000000000000000000000000000000000000..925e26b11289f418e10040b6d3880c157b573778
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/data/smpl_431_adjmat_size.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f2b87183ef2493228e014b5e725f59755a2c75f034684f7cca72c34f7a9d0ae0
+size 175
diff --git a/src/custom_mesh_graphormer/modeling/data/smpl_431_adjmat_values.pt b/src/custom_mesh_graphormer/modeling/data/smpl_431_adjmat_values.pt
new file mode 100644
index 0000000000000000000000000000000000000000..5e0e1b8c594444b1a4df409a61e3217c56dbb171
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/data/smpl_431_adjmat_values.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0d098729e4bc1bdc21545ff99cbe0560c44210a38ff43db7b2858dbfa03e3073
+size 12359
diff --git a/src/custom_mesh_graphormer/modeling/data/smpl_431_faces.npy b/src/custom_mesh_graphormer/modeling/data/smpl_431_faces.npy
new file mode 100644
index 0000000000000000000000000000000000000000..18d0ded99e79da8a8a364e476f527658a8b3ee3a
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/data/smpl_431_faces.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9dc2cfedd4901e31a9ac3c9f60bcaf7647ef41b067e68ceb15ac194b7b6714ae
+size 21128
diff --git a/src/custom_mesh_graphormer/modeling/hrnet/config/__init__.py b/src/custom_mesh_graphormer/modeling/hrnet/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..59be7491faa1cb598b023802f68e0e15732f962b
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/hrnet/config/__init__.py
@@ -0,0 +1,9 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from .default import _C as config
+from .default import update_config
+from .models import MODEL_EXTRAS
diff --git a/src/custom_mesh_graphormer/modeling/hrnet/config/default.py b/src/custom_mesh_graphormer/modeling/hrnet/config/default.py
new file mode 100644
index 0000000000000000000000000000000000000000..59b9843467c25b9512497f9704e1873206eebfa4
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/hrnet/config/default.py
@@ -0,0 +1,138 @@
+
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# Modified by Ke Sun (sunk@mail.ustc.edu.cn)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from yacs.config import CfgNode as CN
+
+
+_C = CN()
+
+_C.OUTPUT_DIR = ''
+_C.LOG_DIR = ''
+_C.DATA_DIR = ''
+_C.GPUS = (0,)
+_C.WORKERS = 4
+_C.PRINT_FREQ = 20
+_C.AUTO_RESUME = False
+_C.PIN_MEMORY = True
+_C.RANK = 0
+
+# Cudnn related params
+_C.CUDNN = CN()
+_C.CUDNN.BENCHMARK = True
+_C.CUDNN.DETERMINISTIC = False
+_C.CUDNN.ENABLED = True
+
+# common params for NETWORK
+_C.MODEL = CN()
+_C.MODEL.NAME = 'cls_hrnet'
+_C.MODEL.INIT_WEIGHTS = True
+_C.MODEL.PRETRAINED = ''
+_C.MODEL.NUM_JOINTS = 17
+_C.MODEL.NUM_CLASSES = 1000
+_C.MODEL.TAG_PER_JOINT = True
+_C.MODEL.TARGET_TYPE = 'gaussian'
+_C.MODEL.IMAGE_SIZE = [256, 256] # width * height, ex: 192 * 256
+_C.MODEL.HEATMAP_SIZE = [64, 64] # width * height, ex: 24 * 32
+_C.MODEL.SIGMA = 2
+_C.MODEL.EXTRA = CN(new_allowed=True)
+
+_C.LOSS = CN()
+_C.LOSS.USE_OHKM = False
+_C.LOSS.TOPK = 8
+_C.LOSS.USE_TARGET_WEIGHT = True
+_C.LOSS.USE_DIFFERENT_JOINTS_WEIGHT = False
+
+# DATASET related params
+_C.DATASET = CN()
+_C.DATASET.ROOT = ''
+_C.DATASET.DATASET = 'mpii'
+_C.DATASET.TRAIN_SET = 'train'
+_C.DATASET.TEST_SET = 'valid'
+_C.DATASET.DATA_FORMAT = 'jpg'
+_C.DATASET.HYBRID_JOINTS_TYPE = ''
+_C.DATASET.SELECT_DATA = False
+
+# training data augmentation
+_C.DATASET.FLIP = True
+_C.DATASET.SCALE_FACTOR = 0.25
+_C.DATASET.ROT_FACTOR = 30
+_C.DATASET.PROB_HALF_BODY = 0.0
+_C.DATASET.NUM_JOINTS_HALF_BODY = 8
+_C.DATASET.COLOR_RGB = False
+
+# train
+_C.TRAIN = CN()
+
+_C.TRAIN.LR_FACTOR = 0.1
+_C.TRAIN.LR_STEP = [90, 110]
+_C.TRAIN.LR = 0.001
+
+_C.TRAIN.OPTIMIZER = 'adam'
+_C.TRAIN.MOMENTUM = 0.9
+_C.TRAIN.WD = 0.0001
+_C.TRAIN.NESTEROV = False
+_C.TRAIN.GAMMA1 = 0.99
+_C.TRAIN.GAMMA2 = 0.0
+
+_C.TRAIN.BEGIN_EPOCH = 0
+_C.TRAIN.END_EPOCH = 140
+
+_C.TRAIN.RESUME = False
+_C.TRAIN.CHECKPOINT = ''
+
+_C.TRAIN.BATCH_SIZE_PER_GPU = 32
+_C.TRAIN.SHUFFLE = True
+
+# testing
+_C.TEST = CN()
+
+# size of images for each device
+_C.TEST.BATCH_SIZE_PER_GPU = 32
+# Test Model Epoch
+_C.TEST.FLIP_TEST = False
+_C.TEST.POST_PROCESS = False
+_C.TEST.SHIFT_HEATMAP = False
+
+_C.TEST.USE_GT_BBOX = False
+
+# nms
+_C.TEST.IMAGE_THRE = 0.1
+_C.TEST.NMS_THRE = 0.6
+_C.TEST.SOFT_NMS = False
+_C.TEST.OKS_THRE = 0.5
+_C.TEST.IN_VIS_THRE = 0.0
+_C.TEST.COCO_BBOX_FILE = ''
+_C.TEST.BBOX_THRE = 1.0
+_C.TEST.MODEL_FILE = ''
+
+# debug
+_C.DEBUG = CN()
+_C.DEBUG.DEBUG = False
+_C.DEBUG.SAVE_BATCH_IMAGES_GT = False
+_C.DEBUG.SAVE_BATCH_IMAGES_PRED = False
+_C.DEBUG.SAVE_HEATMAPS_GT = False
+_C.DEBUG.SAVE_HEATMAPS_PRED = False
+
+
+def update_config(cfg, config_file):
+ cfg.defrost()
+ cfg.merge_from_file(config_file)
+ cfg.freeze()
+
+
+if __name__ == '__main__':
+ import sys
+ with open(sys.argv[1], 'w') as f:
+ print(_C, file=f)
+
diff --git a/src/custom_mesh_graphormer/modeling/hrnet/config/models.py b/src/custom_mesh_graphormer/modeling/hrnet/config/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a73bc9ffd00bfe081496174c808edb93f240cfb
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/hrnet/config/models.py
@@ -0,0 +1,47 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Create by Bin Xiao (Bin.Xiao@microsoft.com)
+# Modified by Ke Sun (sunk@mail.ustc.edu.cn)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from yacs.config import CfgNode as CN
+
+# high_resoluton_net related params for classification
+POSE_HIGH_RESOLUTION_NET = CN()
+POSE_HIGH_RESOLUTION_NET.PRETRAINED_LAYERS = ['*']
+POSE_HIGH_RESOLUTION_NET.STEM_INPLANES = 64
+POSE_HIGH_RESOLUTION_NET.FINAL_CONV_KERNEL = 1
+POSE_HIGH_RESOLUTION_NET.WITH_HEAD = True
+
+POSE_HIGH_RESOLUTION_NET.STAGE2 = CN()
+POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_MODULES = 1
+POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BRANCHES = 2
+POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BLOCKS = [4, 4]
+POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_CHANNELS = [32, 64]
+POSE_HIGH_RESOLUTION_NET.STAGE2.BLOCK = 'BASIC'
+POSE_HIGH_RESOLUTION_NET.STAGE2.FUSE_METHOD = 'SUM'
+
+POSE_HIGH_RESOLUTION_NET.STAGE3 = CN()
+POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_MODULES = 1
+POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BRANCHES = 3
+POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BLOCKS = [4, 4, 4]
+POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_CHANNELS = [32, 64, 128]
+POSE_HIGH_RESOLUTION_NET.STAGE3.BLOCK = 'BASIC'
+POSE_HIGH_RESOLUTION_NET.STAGE3.FUSE_METHOD = 'SUM'
+
+POSE_HIGH_RESOLUTION_NET.STAGE4 = CN()
+POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_MODULES = 1
+POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BRANCHES = 4
+POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
+POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
+POSE_HIGH_RESOLUTION_NET.STAGE4.BLOCK = 'BASIC'
+POSE_HIGH_RESOLUTION_NET.STAGE4.FUSE_METHOD = 'SUM'
+
+MODEL_EXTRAS = {
+ 'cls_hrnet': POSE_HIGH_RESOLUTION_NET,
+}
diff --git a/src/custom_mesh_graphormer/modeling/hrnet/hrnet_cls_net.py b/src/custom_mesh_graphormer/modeling/hrnet/hrnet_cls_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..3388c230c981234b454b14224713720f7db04ce1
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/hrnet/hrnet_cls_net.py
@@ -0,0 +1,523 @@
+
+
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# Modified by Ke Sun (sunk@mail.ustc.edu.cn)
+# Modified by Kevin Lin (keli@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import logging
+import functools
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch._utils
+import torch.nn.functional as F
+import code
+BN_MOMENTUM = 0.1
+logger = logging.getLogger(__name__)
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
+ bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion,
+ momentum=BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class HighResolutionModule(nn.Module):
+ def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
+ num_channels, fuse_method, multi_scale_output=True):
+ super(HighResolutionModule, self).__init__()
+ self._check_branches(
+ num_branches, blocks, num_blocks, num_inchannels, num_channels)
+
+ self.num_inchannels = num_inchannels
+ self.fuse_method = fuse_method
+ self.num_branches = num_branches
+
+ self.multi_scale_output = multi_scale_output
+
+ self.branches = self._make_branches(
+ num_branches, blocks, num_blocks, num_channels)
+ self.fuse_layers = self._make_fuse_layers()
+ self.relu = nn.ReLU(False)
+
+ def _check_branches(self, num_branches, blocks, num_blocks,
+ num_inchannels, num_channels):
+ if num_branches != len(num_blocks):
+ error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
+ num_branches, len(num_blocks))
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ if num_branches != len(num_channels):
+ error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
+ num_branches, len(num_channels))
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ if num_branches != len(num_inchannels):
+ error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
+ num_branches, len(num_inchannels))
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
+ stride=1):
+ downsample = None
+ if stride != 1 or \
+ self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.num_inchannels[branch_index],
+ num_channels[branch_index] * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(num_channels[branch_index] * block.expansion,
+ momentum=BN_MOMENTUM),
+ )
+
+ layers = []
+ layers.append(block(self.num_inchannels[branch_index],
+ num_channels[branch_index], stride, downsample))
+ self.num_inchannels[branch_index] = \
+ num_channels[branch_index] * block.expansion
+ for i in range(1, num_blocks[branch_index]):
+ layers.append(block(self.num_inchannels[branch_index],
+ num_channels[branch_index]))
+
+ return nn.Sequential(*layers)
+
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
+ branches = []
+
+ for i in range(num_branches):
+ branches.append(
+ self._make_one_branch(i, block, num_blocks, num_channels))
+
+ return nn.ModuleList(branches)
+
+ def _make_fuse_layers(self):
+ if self.num_branches == 1:
+ return None
+
+ num_branches = self.num_branches
+ num_inchannels = self.num_inchannels
+ fuse_layers = []
+ for i in range(num_branches if self.multi_scale_output else 1):
+ fuse_layer = []
+ for j in range(num_branches):
+ if j > i:
+ fuse_layer.append(nn.Sequential(
+ nn.Conv2d(num_inchannels[j],
+ num_inchannels[i],
+ 1,
+ 1,
+ 0,
+ bias=False),
+ nn.BatchNorm2d(num_inchannels[i],
+ momentum=BN_MOMENTUM),
+ nn.Upsample(scale_factor=2**(j-i), mode='nearest')))
+ elif j == i:
+ fuse_layer.append(None)
+ else:
+ conv3x3s = []
+ for k in range(i-j):
+ if k == i - j - 1:
+ num_outchannels_conv3x3 = num_inchannels[i]
+ conv3x3s.append(nn.Sequential(
+ nn.Conv2d(num_inchannels[j],
+ num_outchannels_conv3x3,
+ 3, 2, 1, bias=False),
+ nn.BatchNorm2d(num_outchannels_conv3x3,
+ momentum=BN_MOMENTUM)))
+ else:
+ num_outchannels_conv3x3 = num_inchannels[j]
+ conv3x3s.append(nn.Sequential(
+ nn.Conv2d(num_inchannels[j],
+ num_outchannels_conv3x3,
+ 3, 2, 1, bias=False),
+ nn.BatchNorm2d(num_outchannels_conv3x3,
+ momentum=BN_MOMENTUM),
+ nn.ReLU(False)))
+ fuse_layer.append(nn.Sequential(*conv3x3s))
+ fuse_layers.append(nn.ModuleList(fuse_layer))
+
+ return nn.ModuleList(fuse_layers)
+
+ def get_num_inchannels(self):
+ return self.num_inchannels
+
+ def forward(self, x):
+ if self.num_branches == 1:
+ return [self.branches[0](x[0])]
+
+ for i in range(self.num_branches):
+ x[i] = self.branches[i](x[i])
+
+ x_fuse = []
+ for i in range(len(self.fuse_layers)):
+ y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
+ for j in range(1, self.num_branches):
+ if i == j:
+ y = y + x[j]
+ else:
+ y = y + self.fuse_layers[i][j](x[j])
+ x_fuse.append(self.relu(y))
+
+ return x_fuse
+
+
+blocks_dict = {
+ 'BASIC': BasicBlock,
+ 'BOTTLENECK': Bottleneck
+}
+
+
+class HighResolutionNet(nn.Module):
+
+ def __init__(self, cfg, **kwargs):
+ super(HighResolutionNet, self).__init__()
+
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
+ bias=False)
+ self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
+ bias=False)
+ self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.stage1_cfg = cfg['MODEL']['EXTRA']['STAGE1']
+ num_channels = self.stage1_cfg['NUM_CHANNELS'][0]
+ block = blocks_dict[self.stage1_cfg['BLOCK']]
+ num_blocks = self.stage1_cfg['NUM_BLOCKS'][0]
+ self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
+ stage1_out_channel = block.expansion*num_channels
+
+ self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2']
+ num_channels = self.stage2_cfg['NUM_CHANNELS']
+ block = blocks_dict[self.stage2_cfg['BLOCK']]
+ num_channels = [
+ num_channels[i] * block.expansion for i in range(len(num_channels))]
+ self.transition1 = self._make_transition_layer(
+ [stage1_out_channel], num_channels)
+ self.stage2, pre_stage_channels = self._make_stage(
+ self.stage2_cfg, num_channels)
+
+ self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3']
+ num_channels = self.stage3_cfg['NUM_CHANNELS']
+ block = blocks_dict[self.stage3_cfg['BLOCK']]
+ num_channels = [
+ num_channels[i] * block.expansion for i in range(len(num_channels))]
+ self.transition2 = self._make_transition_layer(
+ pre_stage_channels, num_channels)
+ self.stage3, pre_stage_channels = self._make_stage(
+ self.stage3_cfg, num_channels)
+
+ self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4']
+ num_channels = self.stage4_cfg['NUM_CHANNELS']
+ block = blocks_dict[self.stage4_cfg['BLOCK']]
+ num_channels = [
+ num_channels[i] * block.expansion for i in range(len(num_channels))]
+ self.transition3 = self._make_transition_layer(
+ pre_stage_channels, num_channels)
+ self.stage4, pre_stage_channels = self._make_stage(
+ self.stage4_cfg, num_channels, multi_scale_output=True)
+
+ # Classification Head
+ self.incre_modules, self.downsamp_modules, \
+ self.final_layer = self._make_head(pre_stage_channels)
+
+ self.classifier = nn.Linear(2048, 1000)
+
+ def _make_head(self, pre_stage_channels):
+ head_block = Bottleneck
+ head_channels = [32, 64, 128, 256]
+
+ # Increasing the #channels on each resolution
+ # from C, 2C, 4C, 8C to 128, 256, 512, 1024
+ incre_modules = []
+ for i, channels in enumerate(pre_stage_channels):
+ incre_module = self._make_layer(head_block,
+ channels,
+ head_channels[i],
+ 1,
+ stride=1)
+ incre_modules.append(incre_module)
+ incre_modules = nn.ModuleList(incre_modules)
+
+ # downsampling modules
+ downsamp_modules = []
+ for i in range(len(pre_stage_channels)-1):
+ in_channels = head_channels[i] * head_block.expansion
+ out_channels = head_channels[i+1] * head_block.expansion
+
+ downsamp_module = nn.Sequential(
+ nn.Conv2d(in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1),
+ nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM),
+ nn.ReLU(inplace=True)
+ )
+
+ downsamp_modules.append(downsamp_module)
+ downsamp_modules = nn.ModuleList(downsamp_modules)
+
+ final_layer = nn.Sequential(
+ nn.Conv2d(
+ in_channels=head_channels[3] * head_block.expansion,
+ out_channels=2048,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ ),
+ nn.BatchNorm2d(2048, momentum=BN_MOMENTUM),
+ nn.ReLU(inplace=True)
+ )
+
+ return incre_modules, downsamp_modules, final_layer
+
+ def _make_transition_layer(
+ self, num_channels_pre_layer, num_channels_cur_layer):
+ num_branches_cur = len(num_channels_cur_layer)
+ num_branches_pre = len(num_channels_pre_layer)
+
+ transition_layers = []
+ for i in range(num_branches_cur):
+ if i < num_branches_pre:
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+ transition_layers.append(nn.Sequential(
+ nn.Conv2d(num_channels_pre_layer[i],
+ num_channels_cur_layer[i],
+ 3,
+ 1,
+ 1,
+ bias=False),
+ nn.BatchNorm2d(
+ num_channels_cur_layer[i], momentum=BN_MOMENTUM),
+ nn.ReLU(inplace=True)))
+ else:
+ transition_layers.append(None)
+ else:
+ conv3x3s = []
+ for j in range(i+1-num_branches_pre):
+ inchannels = num_channels_pre_layer[-1]
+ outchannels = num_channels_cur_layer[i] \
+ if j == i-num_branches_pre else inchannels
+ conv3x3s.append(nn.Sequential(
+ nn.Conv2d(
+ inchannels, outchannels, 3, 2, 1, bias=False),
+ nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
+ nn.ReLU(inplace=True)))
+ transition_layers.append(nn.Sequential(*conv3x3s))
+
+ return nn.ModuleList(transition_layers)
+
+ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
+ )
+
+ layers = []
+ layers.append(block(inplanes, planes, stride, downsample))
+ inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def _make_stage(self, layer_config, num_inchannels,
+ multi_scale_output=True):
+ num_modules = layer_config['NUM_MODULES']
+ num_branches = layer_config['NUM_BRANCHES']
+ num_blocks = layer_config['NUM_BLOCKS']
+ num_channels = layer_config['NUM_CHANNELS']
+ block = blocks_dict[layer_config['BLOCK']]
+ fuse_method = layer_config['FUSE_METHOD']
+
+ modules = []
+ for i in range(num_modules):
+ # multi_scale_output is only used last module
+ if not multi_scale_output and i == num_modules - 1:
+ reset_multi_scale_output = False
+ else:
+ reset_multi_scale_output = True
+
+ modules.append(
+ HighResolutionModule(num_branches,
+ block,
+ num_blocks,
+ num_inchannels,
+ num_channels,
+ fuse_method,
+ reset_multi_scale_output)
+ )
+ num_inchannels = modules[-1].get_num_inchannels()
+
+ return nn.Sequential(*modules), num_inchannels
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.conv2(x)
+ x = self.bn2(x)
+ x = self.relu(x)
+ x = self.layer1(x)
+
+ x_list = []
+ for i in range(self.stage2_cfg['NUM_BRANCHES']):
+ if self.transition1[i] is not None:
+ x_list.append(self.transition1[i](x))
+ else:
+ x_list.append(x)
+ y_list = self.stage2(x_list)
+
+ x_list = []
+ for i in range(self.stage3_cfg['NUM_BRANCHES']):
+ if self.transition2[i] is not None:
+ x_list.append(self.transition2[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage3(x_list)
+
+ x_list = []
+ for i in range(self.stage4_cfg['NUM_BRANCHES']):
+ if self.transition3[i] is not None:
+ x_list.append(self.transition3[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage4(x_list)
+
+ # Classification Head
+ y = self.incre_modules[0](y_list[0])
+ for i in range(len(self.downsamp_modules)):
+ y = self.incre_modules[i+1](y_list[i+1]) + \
+ self.downsamp_modules[i](y)
+
+ y = self.final_layer(y)
+
+ if torch._C._get_tracing_state():
+ y = y.flatten(start_dim=2).mean(dim=2)
+ else:
+ y = F.avg_pool2d(y, kernel_size=y.size()
+ [2:]).view(y.size(0), -1)
+
+ # y = self.classifier(y)
+
+ return y
+
+ def init_weights(self, pretrained='',):
+ logger.info('=> init weights from normal distribution')
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(
+ m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ if os.path.isfile(pretrained):
+ pretrained_dict = torch.load(pretrained)
+ logger.info('=> loading pretrained model {}'.format(pretrained))
+ print('=> loading pretrained model {}'.format(pretrained))
+ model_dict = self.state_dict()
+ pretrained_dict = {k: v for k, v in pretrained_dict.items()
+ if k in model_dict.keys()}
+ # for k, _ in pretrained_dict.items():
+ # logger.info(
+ # '=> loading {} pretrained model {}'.format(k, pretrained))
+ # print('=> loading {} pretrained model {}'.format(k, pretrained))
+ model_dict.update(pretrained_dict)
+ self.load_state_dict(model_dict)
+ # code.interact(local=locals())
+
+def get_cls_net(config, pretrained, **kwargs):
+ model = HighResolutionNet(config, **kwargs)
+ model.init_weights(pretrained=pretrained)
+ return model
diff --git a/src/custom_mesh_graphormer/modeling/hrnet/hrnet_cls_net_gridfeat.py b/src/custom_mesh_graphormer/modeling/hrnet/hrnet_cls_net_gridfeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c44c948335aa0de8781814beaa2404ea34e6574
--- /dev/null
+++ b/src/custom_mesh_graphormer/modeling/hrnet/hrnet_cls_net_gridfeat.py
@@ -0,0 +1,524 @@
+
+
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# Modified by Ke Sun (sunk@mail.ustc.edu.cn)
+# Modified by Kevin Lin (keli@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import logging
+import functools
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch._utils
+import torch.nn.functional as F
+import code
+BN_MOMENTUM = 0.1
+logger = logging.getLogger(__name__)
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
+ bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion,
+ momentum=BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class HighResolutionModule(nn.Module):
+ def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
+ num_channels, fuse_method, multi_scale_output=True):
+ super(HighResolutionModule, self).__init__()
+ self._check_branches(
+ num_branches, blocks, num_blocks, num_inchannels, num_channels)
+
+ self.num_inchannels = num_inchannels
+ self.fuse_method = fuse_method
+ self.num_branches = num_branches
+
+ self.multi_scale_output = multi_scale_output
+
+ self.branches = self._make_branches(
+ num_branches, blocks, num_blocks, num_channels)
+ self.fuse_layers = self._make_fuse_layers()
+ self.relu = nn.ReLU(False)
+
+ def _check_branches(self, num_branches, blocks, num_blocks,
+ num_inchannels, num_channels):
+ if num_branches != len(num_blocks):
+ error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
+ num_branches, len(num_blocks))
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ if num_branches != len(num_channels):
+ error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
+ num_branches, len(num_channels))
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ if num_branches != len(num_inchannels):
+ error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
+ num_branches, len(num_inchannels))
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
+ stride=1):
+ downsample = None
+ if stride != 1 or \
+ self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.num_inchannels[branch_index],
+ num_channels[branch_index] * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(num_channels[branch_index] * block.expansion,
+ momentum=BN_MOMENTUM),
+ )
+
+ layers = []
+ layers.append(block(self.num_inchannels[branch_index],
+ num_channels[branch_index], stride, downsample))
+ self.num_inchannels[branch_index] = \
+ num_channels[branch_index] * block.expansion
+ for i in range(1, num_blocks[branch_index]):
+ layers.append(block(self.num_inchannels[branch_index],
+ num_channels[branch_index]))
+
+ return nn.Sequential(*layers)
+
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
+ branches = []
+
+ for i in range(num_branches):
+ branches.append(
+ self._make_one_branch(i, block, num_blocks, num_channels))
+
+ return nn.ModuleList(branches)
+
+ def _make_fuse_layers(self):
+ if self.num_branches == 1:
+ return None
+
+ num_branches = self.num_branches
+ num_inchannels = self.num_inchannels
+ fuse_layers = []
+ for i in range(num_branches if self.multi_scale_output else 1):
+ fuse_layer = []
+ for j in range(num_branches):
+ if j > i:
+ fuse_layer.append(nn.Sequential(
+ nn.Conv2d(num_inchannels[j],
+ num_inchannels[i],
+ 1,
+ 1,
+ 0,
+ bias=False),
+ nn.BatchNorm2d(num_inchannels[i],
+ momentum=BN_MOMENTUM),
+ nn.Upsample(scale_factor=2**(j-i), mode='nearest')))
+ elif j == i:
+ fuse_layer.append(None)
+ else:
+ conv3x3s = []
+ for k in range(i-j):
+ if k == i - j - 1:
+ num_outchannels_conv3x3 = num_inchannels[i]
+ conv3x3s.append(nn.Sequential(
+ nn.Conv2d(num_inchannels[j],
+ num_outchannels_conv3x3,
+ 3, 2, 1, bias=False),
+ nn.BatchNorm2d(num_outchannels_conv3x3,
+ momentum=BN_MOMENTUM)))
+ else:
+ num_outchannels_conv3x3 = num_inchannels[j]
+ conv3x3s.append(nn.Sequential(
+ nn.Conv2d(num_inchannels[j],
+ num_outchannels_conv3x3,
+ 3, 2, 1, bias=False),
+ nn.BatchNorm2d(num_outchannels_conv3x3,
+ momentum=BN_MOMENTUM),
+ nn.ReLU(False)))
+ fuse_layer.append(nn.Sequential(*conv3x3s))
+ fuse_layers.append(nn.ModuleList(fuse_layer))
+
+ return nn.ModuleList(fuse_layers)
+
+ def get_num_inchannels(self):
+ return self.num_inchannels
+
+ def forward(self, x):
+ if self.num_branches == 1:
+ return [self.branches[0](x[0])]
+
+ for i in range(self.num_branches):
+ x[i] = self.branches[i](x[i])
+
+ x_fuse = []
+ for i in range(len(self.fuse_layers)):
+ y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
+ for j in range(1, self.num_branches):
+ if i == j:
+ y = y + x[j]
+ else:
+ y = y + self.fuse_layers[i][j](x[j])
+ x_fuse.append(self.relu(y))
+
+ return x_fuse
+
+
+blocks_dict = {
+ 'BASIC': BasicBlock,
+ 'BOTTLENECK': Bottleneck
+}
+
+
+class HighResolutionNet(nn.Module):
+
+ def __init__(self, cfg, **kwargs):
+ super(HighResolutionNet, self).__init__()
+
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
+ bias=False)
+ self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
+ bias=False)
+ self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.stage1_cfg = cfg['MODEL']['EXTRA']['STAGE1']
+ num_channels = self.stage1_cfg['NUM_CHANNELS'][0]
+ block = blocks_dict[self.stage1_cfg['BLOCK']]
+ num_blocks = self.stage1_cfg['NUM_BLOCKS'][0]
+ self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
+ stage1_out_channel = block.expansion*num_channels
+
+ self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2']
+ num_channels = self.stage2_cfg['NUM_CHANNELS']
+ block = blocks_dict[self.stage2_cfg['BLOCK']]
+ num_channels = [
+ num_channels[i] * block.expansion for i in range(len(num_channels))]
+ self.transition1 = self._make_transition_layer(
+ [stage1_out_channel], num_channels)
+ self.stage2, pre_stage_channels = self._make_stage(
+ self.stage2_cfg, num_channels)
+
+ self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3']
+ num_channels = self.stage3_cfg['NUM_CHANNELS']
+ block = blocks_dict[self.stage3_cfg['BLOCK']]
+ num_channels = [
+ num_channels[i] * block.expansion for i in range(len(num_channels))]
+ self.transition2 = self._make_transition_layer(
+ pre_stage_channels, num_channels)
+ self.stage3, pre_stage_channels = self._make_stage(
+ self.stage3_cfg, num_channels)
+
+ self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4']
+ num_channels = self.stage4_cfg['NUM_CHANNELS']
+ block = blocks_dict[self.stage4_cfg['BLOCK']]
+ num_channels = [
+ num_channels[i] * block.expansion for i in range(len(num_channels))]
+ self.transition3 = self._make_transition_layer(
+ pre_stage_channels, num_channels)
+ self.stage4, pre_stage_channels = self._make_stage(
+ self.stage4_cfg, num_channels, multi_scale_output=True)
+
+ # Classification Head
+ self.incre_modules, self.downsamp_modules, \
+ self.final_layer = self._make_head(pre_stage_channels)
+
+ self.classifier = nn.Linear(2048, 1000)
+
+ def _make_head(self, pre_stage_channels):
+ head_block = Bottleneck
+ head_channels = [32, 64, 128, 256]
+
+ # Increasing the #channels on each resolution
+ # from C, 2C, 4C, 8C to 128, 256, 512, 1024
+ incre_modules = []
+ for i, channels in enumerate(pre_stage_channels):
+ incre_module = self._make_layer(head_block,
+ channels,
+ head_channels[i],
+ 1,
+ stride=1)
+ incre_modules.append(incre_module)
+ incre_modules = nn.ModuleList(incre_modules)
+
+ # downsampling modules
+ downsamp_modules = []
+ for i in range(len(pre_stage_channels)-1):
+ in_channels = head_channels[i] * head_block.expansion
+ out_channels = head_channels[i+1] * head_block.expansion
+
+ downsamp_module = nn.Sequential(
+ nn.Conv2d(in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1),
+ nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM),
+ nn.ReLU(inplace=True)
+ )
+
+ downsamp_modules.append(downsamp_module)
+ downsamp_modules = nn.ModuleList(downsamp_modules)
+
+ final_layer = nn.Sequential(
+ nn.Conv2d(
+ in_channels=head_channels[3] * head_block.expansion,
+ out_channels=2048,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ ),
+ nn.BatchNorm2d(2048, momentum=BN_MOMENTUM),
+ nn.ReLU(inplace=True)
+ )
+
+ return incre_modules, downsamp_modules, final_layer
+
+ def _make_transition_layer(
+ self, num_channels_pre_layer, num_channels_cur_layer):
+ num_branches_cur = len(num_channels_cur_layer)
+ num_branches_pre = len(num_channels_pre_layer)
+
+ transition_layers = []
+ for i in range(num_branches_cur):
+ if i < num_branches_pre:
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+ transition_layers.append(nn.Sequential(
+ nn.Conv2d(num_channels_pre_layer[i],
+ num_channels_cur_layer[i],
+ 3,
+ 1,
+ 1,
+ bias=False),
+ nn.BatchNorm2d(
+ num_channels_cur_layer[i], momentum=BN_MOMENTUM),
+ nn.ReLU(inplace=True)))
+ else:
+ transition_layers.append(None)
+ else:
+ conv3x3s = []
+ for j in range(i+1-num_branches_pre):
+ inchannels = num_channels_pre_layer[-1]
+ outchannels = num_channels_cur_layer[i] \
+ if j == i-num_branches_pre else inchannels
+ conv3x3s.append(nn.Sequential(
+ nn.Conv2d(
+ inchannels, outchannels, 3, 2, 1, bias=False),
+ nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
+ nn.ReLU(inplace=True)))
+ transition_layers.append(nn.Sequential(*conv3x3s))
+
+ return nn.ModuleList(transition_layers)
+
+ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
+ )
+
+ layers = []
+ layers.append(block(inplanes, planes, stride, downsample))
+ inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def _make_stage(self, layer_config, num_inchannels,
+ multi_scale_output=True):
+ num_modules = layer_config['NUM_MODULES']
+ num_branches = layer_config['NUM_BRANCHES']
+ num_blocks = layer_config['NUM_BLOCKS']
+ num_channels = layer_config['NUM_CHANNELS']
+ block = blocks_dict[layer_config['BLOCK']]
+ fuse_method = layer_config['FUSE_METHOD']
+
+ modules = []
+ for i in range(num_modules):
+ # multi_scale_output is only used last module
+ if not multi_scale_output and i == num_modules - 1:
+ reset_multi_scale_output = False
+ else:
+ reset_multi_scale_output = True
+
+ modules.append(
+ HighResolutionModule(num_branches,
+ block,
+ num_blocks,
+ num_inchannels,
+ num_channels,
+ fuse_method,
+ reset_multi_scale_output)
+ )
+ num_inchannels = modules[-1].get_num_inchannels()
+
+ return nn.Sequential(*modules), num_inchannels
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.conv2(x)
+ x = self.bn2(x)
+ x = self.relu(x)
+ x = self.layer1(x)
+
+ x_list = []
+ for i in range(self.stage2_cfg['NUM_BRANCHES']):
+ if self.transition1[i] is not None:
+ x_list.append(self.transition1[i](x))
+ else:
+ x_list.append(x)
+ y_list = self.stage2(x_list)
+
+ x_list = []
+ for i in range(self.stage3_cfg['NUM_BRANCHES']):
+ if self.transition2[i] is not None:
+ x_list.append(self.transition2[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage3(x_list)
+
+ x_list = []
+ for i in range(self.stage4_cfg['NUM_BRANCHES']):
+ if self.transition3[i] is not None:
+ x_list.append(self.transition3[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage4(x_list)
+
+ # Classification Head
+ y = self.incre_modules[0](y_list[0])
+ for i in range(len(self.downsamp_modules)):
+ y = self.incre_modules[i+1](y_list[i+1]) + \
+ self.downsamp_modules[i](y)
+
+ yy = self.final_layer(y)
+
+ if torch._C._get_tracing_state():
+ yy = yy.flatten(start_dim=2).mean(dim=2)
+ else:
+ yy = F.avg_pool2d(yy, kernel_size=yy.size()
+ [2:]).view(yy.size(0), -1)
+
+ # y = self.classifier(y)
+ return yy, y
+
+
+
+ def init_weights(self, pretrained='',):
+ logger.info('=> init weights from normal distribution')
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(
+ m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ if os.path.isfile(pretrained):
+ pretrained_dict = torch.load(pretrained)
+ logger.info('=> loading pretrained model {}'.format(pretrained))
+ print('=> loading pretrained model {}'.format(pretrained))
+ model_dict = self.state_dict()
+ pretrained_dict = {k: v for k, v in pretrained_dict.items()
+ if k in model_dict.keys()}
+ # for k, _ in pretrained_dict.items():
+ # logger.info(
+ # '=> loading {} pretrained model {}'.format(k, pretrained))
+ # print('=> loading {} pretrained model {}'.format(k, pretrained))
+ model_dict.update(pretrained_dict)
+ self.load_state_dict(model_dict)
+ # code.interact(local=locals())
+
+def get_cls_net_gridfeat(config, pretrained, **kwargs):
+ model = HighResolutionNet(config, **kwargs)
+ model.init_weights(pretrained=pretrained)
+ return model
diff --git a/src/custom_mesh_graphormer/tools/run_gphmer_bodymesh.py b/src/custom_mesh_graphormer/tools/run_gphmer_bodymesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..83685c4eb09f4ef31b64b0070b514113456e6f9a
--- /dev/null
+++ b/src/custom_mesh_graphormer/tools/run_gphmer_bodymesh.py
@@ -0,0 +1,750 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+Training and evaluation codes for
+3D human body mesh reconstruction from an image
+"""
+
+from __future__ import absolute_import, division, print_function
+import argparse
+import os
+import os.path as op
+import code
+import json
+import time
+import datetime
+import torch
+import torchvision.models as models
+from torchvision.utils import make_grid
+import gc
+import numpy as np
+import cv2
+from custom_mesh_graphormer.modeling.bert import BertConfig, Graphormer
+from custom_mesh_graphormer.modeling.bert import Graphormer_Body_Network as Graphormer_Network
+from custom_mesh_graphormer.modeling._smpl import SMPL, Mesh
+from custom_mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat
+from custom_mesh_graphormer.modeling.hrnet.config import config as hrnet_config
+from custom_mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
+import custom_mesh_graphormer.modeling.data.config as cfg
+from custom_mesh_graphormer.datasets.build import make_data_loader
+
+from custom_mesh_graphormer.utils.logger import setup_logger
+from custom_mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather
+from custom_mesh_graphormer.utils.miscellaneous import mkdir, set_seed
+from custom_mesh_graphormer.utils.metric_logger import AverageMeter, EvalMetricsLogger
+from custom_mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction, visualize_reconstruction_test
+from custom_mesh_graphormer.utils.metric_pampjpe import reconstruction_error
+from custom_mesh_graphormer.utils.geometric_layers import orthographic_projection
+
+from comfy.model_management import get_torch_device
+device = get_torch_device()
+
+from azureml.core.run import Run
+aml_run = Run.get_context()
+
+def save_checkpoint(model, args, epoch, iteration, num_trial=10):
+ checkpoint_dir = op.join(args.output_dir, 'checkpoint-{}-{}'.format(
+ epoch, iteration))
+ if not is_main_process():
+ return checkpoint_dir
+ mkdir(checkpoint_dir)
+ model_to_save = model.module if hasattr(model, 'module') else model
+ for i in range(num_trial):
+ try:
+ torch.save(model_to_save, op.join(checkpoint_dir, 'model.bin'))
+ torch.save(model_to_save.state_dict(), op.join(checkpoint_dir, 'state_dict.bin'))
+ torch.save(args, op.join(checkpoint_dir, 'training_args.bin'))
+ logger.info("Save checkpoint to {}".format(checkpoint_dir))
+ break
+ except:
+ pass
+ else:
+ logger.info("Failed to save checkpoint after {} trails.".format(num_trial))
+ return checkpoint_dir
+
+def save_scores(args, split, mpjpe, pampjpe, mpve):
+ eval_log = []
+ res = {}
+ res['mPJPE'] = mpjpe
+ res['PAmPJPE'] = pampjpe
+ res['mPVE'] = mpve
+ eval_log.append(res)
+ with open(op.join(args.output_dir, split+'_eval_logs.json'), 'w') as f:
+ json.dump(eval_log, f)
+ logger.info("Save eval scores to {}".format(args.output_dir))
+ return
+
+def adjust_learning_rate(optimizer, epoch, args):
+ """
+ Sets the learning rate to the initial LR decayed by x every y epochs
+ x = 0.1, y = args.num_train_epochs/2.0 = 100
+ """
+ lr = args.lr * (0.1 ** (epoch // (args.num_train_epochs/2.0) ))
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = lr
+
+def mean_per_joint_position_error(pred, gt, has_3d_joints):
+ """
+ Compute mPJPE
+ """
+ gt = gt[has_3d_joints == 1]
+ gt = gt[:, :, :-1]
+ pred = pred[has_3d_joints == 1]
+
+ with torch.no_grad():
+ gt_pelvis = (gt[:, 2,:] + gt[:, 3,:]) / 2
+ gt = gt - gt_pelvis[:, None, :]
+ pred_pelvis = (pred[:, 2,:] + pred[:, 3,:]) / 2
+ pred = pred - pred_pelvis[:, None, :]
+ error = torch.sqrt( ((pred - gt) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
+ return error
+
+def mean_per_vertex_error(pred, gt, has_smpl):
+ """
+ Compute mPVE
+ """
+ pred = pred[has_smpl == 1]
+ gt = gt[has_smpl == 1]
+ with torch.no_grad():
+ error = torch.sqrt( ((pred - gt) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
+ return error
+
+def keypoint_2d_loss(criterion_keypoints, pred_keypoints_2d, gt_keypoints_2d, has_pose_2d):
+ """
+ Compute 2D reprojection loss if 2D keypoint annotations are available.
+ The confidence (conf) is binary and indicates whether the keypoints exist or not.
+ """
+ conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
+ loss = (conf * criterion_keypoints(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).mean()
+ return loss
+
+def keypoint_3d_loss(criterion_keypoints, pred_keypoints_3d, gt_keypoints_3d, has_pose_3d, device):
+ """
+ Compute 3D keypoint loss if 3D keypoint annotations are available.
+ """
+ conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
+ gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone()
+ gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1]
+ conf = conf[has_pose_3d == 1]
+ pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1]
+ if len(gt_keypoints_3d) > 0:
+ gt_pelvis = (gt_keypoints_3d[:, 2,:] + gt_keypoints_3d[:, 3,:]) / 2
+ gt_keypoints_3d = gt_keypoints_3d - gt_pelvis[:, None, :]
+ pred_pelvis = (pred_keypoints_3d[:, 2,:] + pred_keypoints_3d[:, 3,:]) / 2
+ pred_keypoints_3d = pred_keypoints_3d - pred_pelvis[:, None, :]
+ return (conf * criterion_keypoints(pred_keypoints_3d, gt_keypoints_3d)).mean()
+ else:
+ return torch.FloatTensor(1).fill_(0.).to(device)
+
+def vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl, device):
+ """
+ Compute per-vertex loss if vertex annotations are available.
+ """
+ pred_vertices_with_shape = pred_vertices[has_smpl == 1]
+ gt_vertices_with_shape = gt_vertices[has_smpl == 1]
+ if len(gt_vertices_with_shape) > 0:
+ return criterion_vertices(pred_vertices_with_shape, gt_vertices_with_shape)
+ else:
+ return torch.FloatTensor(1).fill_(0.).to(device)
+
+def rectify_pose(pose):
+ pose = pose.copy()
+ R_mod = cv2.Rodrigues(np.array([np.pi, 0, 0]))[0]
+ R_root = cv2.Rodrigues(pose[:3])[0]
+ new_root = R_root.dot(R_mod)
+ pose[:3] = cv2.Rodrigues(new_root)[0].reshape(3)
+ return pose
+
+def run(args, train_dataloader, val_dataloader, Graphormer_model, smpl, mesh_sampler, renderer):
+ smpl.eval()
+ max_iter = len(train_dataloader)
+ iters_per_epoch = max_iter // args.num_train_epochs
+ if iters_per_epoch<1000:
+ args.logging_steps = 500
+
+ optimizer = torch.optim.Adam(params=list(Graphormer_model.parameters()),
+ lr=args.lr,
+ betas=(0.9, 0.999),
+ weight_decay=0)
+
+ # define loss function (criterion) and optimizer
+ criterion_2d_keypoints = torch.nn.MSELoss(reduction='none').to(device)
+ criterion_keypoints = torch.nn.MSELoss(reduction='none').to(device)
+ criterion_vertices = torch.nn.L1Loss().to(device)
+
+ if args.distributed:
+ Graphormer_model = torch.nn.parallel.DistributedDataParallel(
+ Graphormer_model, device_ids=[args.local_rank],
+ output_device=args.local_rank,
+ find_unused_parameters=True,
+ )
+
+ logger.info(
+ ' '.join(
+ ['Local rank: {o}', 'Max iteration: {a}', 'iters_per_epoch: {b}','num_train_epochs: {c}',]
+ ).format(o=args.local_rank, a=max_iter, b=iters_per_epoch, c=args.num_train_epochs)
+ )
+
+ start_training_time = time.time()
+ end = time.time()
+ Graphormer_model.train()
+ batch_time = AverageMeter()
+ data_time = AverageMeter()
+ log_losses = AverageMeter()
+ log_loss_2djoints = AverageMeter()
+ log_loss_3djoints = AverageMeter()
+ log_loss_vertices = AverageMeter()
+ log_eval_metrics = EvalMetricsLogger()
+
+ for iteration, (img_keys, images, annotations) in enumerate(train_dataloader):
+ # gc.collect()
+ # torch.cuda.empty_cache()
+ Graphormer_model.train()
+ iteration += 1
+ epoch = iteration // iters_per_epoch
+ batch_size = images.size(0)
+ adjust_learning_rate(optimizer, epoch, args)
+ data_time.update(time.time() - end)
+
+ images = images.to(device)
+ gt_2d_joints = annotations['joints_2d'].to(device)
+ gt_2d_joints = gt_2d_joints[:,cfg.J24_TO_J14,:]
+ has_2d_joints = annotations['has_2d_joints'].to(device)
+
+ gt_3d_joints = annotations['joints_3d'].to(device)
+ gt_3d_pelvis = gt_3d_joints[:,cfg.J24_NAME.index('Pelvis'),:3]
+ gt_3d_joints = gt_3d_joints[:,cfg.J24_TO_J14,:]
+ gt_3d_joints[:,:,:3] = gt_3d_joints[:,:,:3] - gt_3d_pelvis[:, None, :]
+ has_3d_joints = annotations['has_3d_joints'].to(device)
+
+ gt_pose = annotations['pose'].to(device)
+ gt_betas = annotations['betas'].to(device)
+ has_smpl = annotations['has_smpl'].to(device)
+ mjm_mask = annotations['mjm_mask'].to(device)
+ mvm_mask = annotations['mvm_mask'].to(device)
+
+ # generate simplified mesh
+ gt_vertices = smpl(gt_pose, gt_betas)
+ gt_vertices_sub2 = mesh_sampler.downsample(gt_vertices, n1=0, n2=2)
+ gt_vertices_sub = mesh_sampler.downsample(gt_vertices)
+
+ # normalize gt based on smpl's pelvis
+ gt_smpl_3d_joints = smpl.get_h36m_joints(gt_vertices)
+ gt_smpl_3d_pelvis = gt_smpl_3d_joints[:,cfg.H36M_J17_NAME.index('Pelvis'),:]
+ gt_vertices_sub2 = gt_vertices_sub2 - gt_smpl_3d_pelvis[:, None, :]
+
+ # prepare masks for mask vertex/joint modeling
+ mjm_mask_ = mjm_mask.expand(-1,-1,2051)
+ mvm_mask_ = mvm_mask.expand(-1,-1,2051)
+ meta_masks = torch.cat([mjm_mask_, mvm_mask_], dim=1)
+
+ # forward-pass
+ pred_camera, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices = Graphormer_model(images, smpl, mesh_sampler, meta_masks=meta_masks, is_train=True)
+
+ # normalize gt based on smpl's pelvis
+ gt_vertices_sub = gt_vertices_sub - gt_smpl_3d_pelvis[:, None, :]
+ gt_vertices = gt_vertices - gt_smpl_3d_pelvis[:, None, :]
+
+ # obtain 3d joints, which are regressed from the full mesh
+ pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices)
+ pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,cfg.H36M_J17_TO_J14,:]
+
+ # obtain 2d joints, which are projected from 3d joints of smpl mesh
+ pred_2d_joints_from_smpl = orthographic_projection(pred_3d_joints_from_smpl, pred_camera)
+ pred_2d_joints = orthographic_projection(pred_3d_joints, pred_camera)
+
+ # compute 3d joint loss (where the joints are directly output from transformer)
+ loss_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints, gt_3d_joints, has_3d_joints, args.device)
+ # compute 3d vertex loss
+ loss_vertices = ( args.vloss_w_sub2 * vertices_loss(criterion_vertices, pred_vertices_sub2, gt_vertices_sub2, has_smpl, args.device) + \
+ args.vloss_w_sub * vertices_loss(criterion_vertices, pred_vertices_sub, gt_vertices_sub, has_smpl, args.device) + \
+ args.vloss_w_full * vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl, args.device) )
+ # compute 3d joint loss (where the joints are regressed from full mesh)
+ loss_reg_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints_from_smpl, gt_3d_joints, has_3d_joints, args.device)
+ # compute 2d joint loss
+ loss_2d_joints = keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints, gt_2d_joints, has_2d_joints) + \
+ keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints_from_smpl, gt_2d_joints, has_2d_joints)
+
+ loss_3d_joints = loss_3d_joints + loss_reg_3d_joints
+
+ # we empirically use hyperparameters to balance difference losses
+ loss = args.joints_loss_weight*loss_3d_joints + \
+ args.vertices_loss_weight*loss_vertices + args.vertices_loss_weight*loss_2d_joints
+
+ # update logs
+ log_loss_2djoints.update(loss_2d_joints.item(), batch_size)
+ log_loss_3djoints.update(loss_3d_joints.item(), batch_size)
+ log_loss_vertices.update(loss_vertices.item(), batch_size)
+ log_losses.update(loss.item(), batch_size)
+
+ # back prop
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ batch_time.update(time.time() - end)
+ end = time.time()
+
+ if iteration % args.logging_steps == 0 or iteration == max_iter:
+ eta_seconds = batch_time.avg * (max_iter - iteration)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ logger.info(
+ ' '.join(
+ ['eta: {eta}', 'epoch: {ep}', 'iter: {iter}', 'max mem : {memory:.0f}',]
+ ).format(eta=eta_string, ep=epoch, iter=iteration,
+ memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
+ + ' loss: {:.4f}, 2d joint loss: {:.4f}, 3d joint loss: {:.4f}, vertex loss: {:.4f}, compute: {:.4f}, data: {:.4f}, lr: {:.6f}'.format(
+ log_losses.avg, log_loss_2djoints.avg, log_loss_3djoints.avg, log_loss_vertices.avg, batch_time.avg, data_time.avg,
+ optimizer.param_groups[0]['lr'])
+ )
+
+ aml_run.log(name='Loss', value=float(log_losses.avg))
+ aml_run.log(name='3d joint Loss', value=float(log_loss_3djoints.avg))
+ aml_run.log(name='2d joint Loss', value=float(log_loss_2djoints.avg))
+ aml_run.log(name='vertex Loss', value=float(log_loss_vertices.avg))
+
+ visual_imgs = visualize_mesh( renderer,
+ annotations['ori_img'].detach(),
+ annotations['joints_2d'].detach(),
+ pred_vertices.detach(),
+ pred_camera.detach(),
+ pred_2d_joints_from_smpl.detach())
+ visual_imgs = visual_imgs.transpose(0,1)
+ visual_imgs = visual_imgs.transpose(1,2)
+ visual_imgs = np.asarray(visual_imgs)
+
+ if is_main_process()==True:
+ stamp = str(epoch) + '_' + str(iteration)
+ temp_fname = args.output_dir + 'visual_' + stamp + '.jpg'
+ cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255))
+ aml_run.log_image(name='visual results', path=temp_fname)
+
+ if iteration % iters_per_epoch == 0:
+ val_mPVE, val_mPJPE, val_PAmPJPE, val_count = run_validate(args, val_dataloader,
+ Graphormer_model,
+ criterion_keypoints,
+ criterion_vertices,
+ epoch,
+ smpl,
+ mesh_sampler)
+ aml_run.log(name='mPVE', value=float(1000*val_mPVE))
+ aml_run.log(name='mPJPE', value=float(1000*val_mPJPE))
+ aml_run.log(name='PAmPJPE', value=float(1000*val_PAmPJPE))
+ logger.info(
+ ' '.join(['Validation', 'epoch: {ep}',]).format(ep=epoch)
+ + ' mPVE: {:6.2f}, mPJPE: {:6.2f}, PAmPJPE: {:6.2f}, Data Count: {:6.2f}'.format(1000*val_mPVE, 1000*val_mPJPE, 1000*val_PAmPJPE, val_count)
+ )
+
+ if val_PAmPJPE0:
+ mPVE.update(np.mean(error_vertices), int(torch.sum(has_smpl)) )
+ if len(error_joints)>0:
+ mPJPE.update(np.mean(error_joints), int(torch.sum(has_3d_joints)) )
+ if len(error_joints_pa)>0:
+ PAmPJPE.update(np.mean(error_joints_pa), int(torch.sum(has_3d_joints)) )
+
+ val_mPVE = all_gather(float(mPVE.avg))
+ val_mPVE = sum(val_mPVE)/len(val_mPVE)
+ val_mPJPE = all_gather(float(mPJPE.avg))
+ val_mPJPE = sum(val_mPJPE)/len(val_mPJPE)
+
+ val_PAmPJPE = all_gather(float(PAmPJPE.avg))
+ val_PAmPJPE = sum(val_PAmPJPE)/len(val_PAmPJPE)
+
+ val_count = all_gather(float(mPVE.count))
+ val_count = sum(val_count)
+
+ return val_mPVE, val_mPJPE, val_PAmPJPE, val_count
+
+
+def visualize_mesh( renderer,
+ images,
+ gt_keypoints_2d,
+ pred_vertices,
+ pred_camera,
+ pred_keypoints_2d):
+ """Tensorboard logging."""
+ gt_keypoints_2d = gt_keypoints_2d.cpu().numpy()
+ to_lsp = list(range(14))
+ rend_imgs = []
+ batch_size = pred_vertices.shape[0]
+ # Do visualization for the first 6 images of the batch
+ for i in range(min(batch_size, 10)):
+ img = images[i].cpu().numpy().transpose(1,2,0)
+ # Get LSP keypoints from the full list of keypoints
+ gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp]
+ pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp]
+ # Get predict vertices for the particular example
+ vertices = pred_vertices[i].cpu().numpy()
+ cam = pred_camera[i].cpu().numpy()
+ # Visualize reconstruction and detected pose
+ rend_img = visualize_reconstruction(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer)
+ rend_img = rend_img.transpose(2,0,1)
+ rend_imgs.append(torch.from_numpy(rend_img))
+ rend_imgs = make_grid(rend_imgs, nrow=1)
+ return rend_imgs
+
+def visualize_mesh_test( renderer,
+ images,
+ gt_keypoints_2d,
+ pred_vertices,
+ pred_camera,
+ pred_keypoints_2d,
+ PAmPJPE_h36m_j14):
+ """Tensorboard logging."""
+ gt_keypoints_2d = gt_keypoints_2d.cpu().numpy()
+ to_lsp = list(range(14))
+ rend_imgs = []
+ batch_size = pred_vertices.shape[0]
+ # Do visualization for the first 6 images of the batch
+ for i in range(min(batch_size, 10)):
+ img = images[i].cpu().numpy().transpose(1,2,0)
+ # Get LSP keypoints from the full list of keypoints
+ gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp]
+ pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp]
+ # Get predict vertices for the particular example
+ vertices = pred_vertices[i].cpu().numpy()
+ cam = pred_camera[i].cpu().numpy()
+ score = PAmPJPE_h36m_j14[i]
+ # Visualize reconstruction and detected pose
+ rend_img = visualize_reconstruction_test(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer, score)
+ rend_img = rend_img.transpose(2,0,1)
+ rend_imgs.append(torch.from_numpy(rend_img))
+ rend_imgs = make_grid(rend_imgs, nrow=1)
+ return rend_imgs
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ #########################################################
+ # Data related arguments
+ #########################################################
+ parser.add_argument("--data_dir", default='datasets', type=str, required=False,
+ help="Directory with all datasets, each in one subfolder")
+ parser.add_argument("--train_yaml", default='imagenet2012/train.yaml', type=str, required=False,
+ help="Yaml file with all data for training.")
+ parser.add_argument("--val_yaml", default='imagenet2012/test.yaml', type=str, required=False,
+ help="Yaml file with all data for validation.")
+ parser.add_argument("--num_workers", default=4, type=int,
+ help="Workers in dataloader.")
+ parser.add_argument("--img_scale_factor", default=1, type=int,
+ help="adjust image resolution.")
+ #########################################################
+ # Loading/saving checkpoints
+ #########################################################
+ parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False,
+ help="Path to pre-trained transformer model or model type.")
+ parser.add_argument("--resume_checkpoint", default=None, type=str, required=False,
+ help="Path to specific checkpoint for resume training.")
+ parser.add_argument("--output_dir", default='output/', type=str, required=False,
+ help="The output directory to save checkpoint and test results.")
+ parser.add_argument("--config_name", default="", type=str,
+ help="Pretrained config name or path if not the same as model_name.")
+ #########################################################
+ # Training parameters
+ #########################################################
+ parser.add_argument("--per_gpu_train_batch_size", default=30, type=int,
+ help="Batch size per GPU/CPU for training.")
+ parser.add_argument("--per_gpu_eval_batch_size", default=30, type=int,
+ help="Batch size per GPU/CPU for evaluation.")
+ parser.add_argument('--lr', "--learning_rate", default=1e-4, type=float,
+ help="The initial lr.")
+ parser.add_argument("--num_train_epochs", default=200, type=int,
+ help="Total number of training epochs to perform.")
+ parser.add_argument("--vertices_loss_weight", default=100.0, type=float)
+ parser.add_argument("--joints_loss_weight", default=1000.0, type=float)
+ parser.add_argument("--vloss_w_full", default=0.33, type=float)
+ parser.add_argument("--vloss_w_sub", default=0.33, type=float)
+ parser.add_argument("--vloss_w_sub2", default=0.33, type=float)
+ parser.add_argument("--drop_out", default=0.1, type=float,
+ help="Drop out ratio in BERT.")
+ #########################################################
+ # Model architectures
+ #########################################################
+ parser.add_argument('-a', '--arch', default='hrnet-w64',
+ help='CNN backbone architecture: hrnet-w64, hrnet, resnet50')
+ parser.add_argument("--num_hidden_layers", default=4, type=int, required=False,
+ help="Update model config if given")
+ parser.add_argument("--hidden_size", default=-1, type=int, required=False,
+ help="Update model config if given")
+ parser.add_argument("--num_attention_heads", default=4, type=int, required=False,
+ help="Update model config if given. Note that the division of "
+ "hidden_size / num_attention_heads should be in integer.")
+ parser.add_argument("--intermediate_size", default=-1, type=int, required=False,
+ help="Update model config if given.")
+ parser.add_argument("--input_feat_dim", default='2051,512,128', type=str,
+ help="The Image Feature Dimension.")
+ parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str,
+ help="The Image Feature Dimension.")
+ parser.add_argument("--which_gcn", default='0,0,1', type=str,
+ help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. Default: only Encoder3 has graph conv")
+ parser.add_argument("--mesh_type", default='body', type=str, help="body or hand")
+ parser.add_argument("--interm_size_scale", default=2, type=int)
+ #########################################################
+ # Others
+ #########################################################
+ parser.add_argument("--run_eval_only", default=False, action='store_true',)
+ parser.add_argument('--logging_steps', type=int, default=1000,
+ help="Log every X steps.")
+ parser.add_argument("--device", type=str, default='cuda',
+ help="cuda or cpu")
+ parser.add_argument('--seed', type=int, default=88,
+ help="random seed for initialization.")
+ parser.add_argument("--local_rank", type=int, default=0,
+ help="For distributed training.")
+
+
+ args = parser.parse_args()
+ return args
+
+
+def main(args):
+ global logger
+ # Setup CUDA, GPU & distributed training
+ args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
+ os.environ['OMP_NUM_THREADS'] = str(args.num_workers)
+ print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS']))
+
+ args.distributed = args.num_gpus > 1
+ args.device = torch.device(args.device)
+ if args.distributed:
+ print("Init distributed training on local rank {} ({}), rank {}, world size {}".format(args.local_rank, int(os.environ["LOCAL_RANK"]), int(os.environ["NODE_RANK"]), args.num_gpus))
+ torch.cuda.set_device(args.local_rank)
+ torch.distributed.init_process_group(
+ backend='nccl', init_method='env://'
+ )
+ local_rank = int(os.environ["LOCAL_RANK"])
+ args.device = torch.device("cuda", local_rank)
+ synchronize()
+
+ mkdir(args.output_dir)
+ logger = setup_logger("Graphormer", args.output_dir, get_rank())
+ set_seed(args.seed, args.num_gpus)
+ logger.info("Using {} GPUs".format(args.num_gpus))
+
+ # Mesh and SMPL utils
+ smpl = SMPL().to(args.device)
+ mesh_sampler = Mesh()
+
+ # Renderer for visualization
+ renderer = Renderer(faces=smpl.faces.cpu().numpy())
+
+ # Load model
+ trans_encoder = []
+
+ input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')]
+ hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')]
+ output_feat_dim = input_feat_dim[1:] + [3]
+
+ # which encoder block to have graph convs
+ which_blk_graph = [int(item) for item in args.which_gcn.split(',')]
+
+ if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint:
+ # if only run eval, load checkpoint
+ logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint))
+ _model = torch.load(args.resume_checkpoint)
+ else:
+ # init three transformer-encoder blocks in a loop
+ for i in range(len(output_feat_dim)):
+ config_class, model_class = BertConfig, Graphormer
+ config = config_class.from_pretrained(args.config_name if args.config_name \
+ else args.model_name_or_path)
+
+ config.output_attentions = False
+ config.hidden_dropout_prob = args.drop_out
+ config.img_feature_dim = input_feat_dim[i]
+ config.output_feature_dim = output_feat_dim[i]
+ args.hidden_size = hidden_feat_dim[i]
+ args.intermediate_size = int(args.hidden_size*args.interm_size_scale)
+
+ if which_blk_graph[i]==1:
+ config.graph_conv = True
+ logger.info("Add Graph Conv")
+ else:
+ config.graph_conv = False
+
+ config.mesh_type = args.mesh_type
+
+ # update model structure if specified in arguments
+ update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size']
+
+ for idx, param in enumerate(update_params):
+ arg_param = getattr(args, param)
+ config_param = getattr(config, param)
+ if arg_param > 0 and arg_param != config_param:
+ logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param))
+ setattr(config, param, arg_param)
+
+ # init a transformer encoder and append it to a list
+ assert config.hidden_size % config.num_attention_heads == 0
+ model = model_class(config=config)
+ logger.info("Init model from scratch.")
+ trans_encoder.append(model)
+
+
+ # init ImageNet pre-trained backbone model
+ if args.arch=='hrnet':
+ hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
+ hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth'
+ hrnet_update_config(hrnet_config, hrnet_yaml)
+ backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
+ logger.info('=> loading hrnet-v2-w40 model')
+ elif args.arch=='hrnet-w64':
+ hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
+ hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth'
+ hrnet_update_config(hrnet_config, hrnet_yaml)
+ backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
+ logger.info('=> loading hrnet-v2-w64 model')
+ else:
+ print("=> using pre-trained model '{}'".format(args.arch))
+ backbone = models.__dict__[args.arch](pretrained=True)
+ # remove the last fc layer
+ backbone = torch.nn.Sequential(*list(backbone.children())[:-2])
+
+
+ trans_encoder = torch.nn.Sequential(*trans_encoder)
+ total_params = sum(p.numel() for p in trans_encoder.parameters())
+ logger.info('Graphormer encoders total parameters: {}'.format(total_params))
+ backbone_total_params = sum(p.numel() for p in backbone.parameters())
+ logger.info('Backbone total parameters: {}'.format(backbone_total_params))
+
+ # build end-to-end Graphormer network (CNN backbone + multi-layer graphormer encoder)
+ _model = Graphormer_Network(args, config, backbone, trans_encoder, mesh_sampler)
+
+ if args.resume_checkpoint!=None and args.resume_checkpoint!='None':
+ # for fine-tuning or resume training or inference, load weights from checkpoint
+ logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint))
+ # workaround approach to load sparse tensor in graph conv.
+ states = torch.load(args.resume_checkpoint)
+ # states = checkpoint_loaded.state_dict()
+ for k, v in states.items():
+ states[k] = v.cpu()
+ # del checkpoint_loaded
+ _model.load_state_dict(states, strict=False)
+ del states
+ gc.collect()
+ torch.cuda.empty_cache()
+
+
+ _model.to(args.device)
+ logger.info("Training parameters %s", args)
+
+ if args.run_eval_only==True:
+ val_dataloader = make_data_loader(args, args.val_yaml,
+ args.distributed, is_train=False, scale_factor=args.img_scale_factor)
+ run_eval_general(args, val_dataloader, _model, smpl, mesh_sampler)
+
+ else:
+ train_dataloader = make_data_loader(args, args.train_yaml,
+ args.distributed, is_train=True, scale_factor=args.img_scale_factor)
+ val_dataloader = make_data_loader(args, args.val_yaml,
+ args.distributed, is_train=False, scale_factor=args.img_scale_factor)
+ run(args, train_dataloader, val_dataloader, _model, smpl, mesh_sampler, renderer)
+
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/src/custom_mesh_graphormer/tools/run_gphmer_bodymesh_inference.py b/src/custom_mesh_graphormer/tools/run_gphmer_bodymesh_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d67fc92740d2a34312d5a91f4facfe3ab1dd914
--- /dev/null
+++ b/src/custom_mesh_graphormer/tools/run_gphmer_bodymesh_inference.py
@@ -0,0 +1,351 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+End-to-end inference codes for
+3D human body mesh reconstruction from an image
+"""
+
+from __future__ import absolute_import, division, print_function
+import argparse
+import os
+import os.path as op
+import code
+import json
+import time
+import datetime
+import torch
+import torchvision.models as models
+from torchvision.utils import make_grid
+import gc
+import numpy as np
+import cv2
+from custom_mesh_graphormer.modeling.bert import BertConfig, Graphormer
+from custom_mesh_graphormer.modeling.bert import Graphormer_Body_Network as Graphormer_Network
+from custom_mesh_graphormer.modeling._smpl import SMPL, Mesh
+from custom_mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat
+from custom_mesh_graphormer.modeling.hrnet.config import config as hrnet_config
+from custom_mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
+import custom_mesh_graphormer.modeling.data.config as cfg
+from custom_mesh_graphormer.datasets.build import make_data_loader
+
+from custom_mesh_graphormer.utils.logger import setup_logger
+from custom_mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather
+from custom_mesh_graphormer.utils.miscellaneous import mkdir, set_seed
+from custom_mesh_graphormer.utils.metric_logger import AverageMeter, EvalMetricsLogger
+from custom_mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction_and_att_local, visualize_reconstruction_no_text
+from custom_mesh_graphormer.utils.metric_pampjpe import reconstruction_error
+from custom_mesh_graphormer.utils.geometric_layers import orthographic_projection
+
+from PIL import Image
+from torchvision import transforms
+
+from comfy.model_management import get_torch_device
+device = get_torch_device()
+
+transform = transforms.Compose([
+ transforms.Resize(224),
+ transforms.CenterCrop(224),
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])])
+
+transform_visualize = transforms.Compose([
+ transforms.Resize(224),
+ transforms.CenterCrop(224),
+ transforms.ToTensor()])
+
+def run_inference(args, image_list, Graphormer_model, smpl, renderer, mesh_sampler):
+ # switch to evaluate mode
+ Graphormer_model.eval()
+ smpl.eval()
+ with torch.no_grad():
+ for image_file in image_list:
+ if 'pred' not in image_file:
+ att_all = []
+ img = Image.open(image_file)
+ img_tensor = transform(img)
+ img_visual = transform_visualize(img)
+
+ batch_imgs = torch.unsqueeze(img_tensor, 0).to(device)
+ batch_visual_imgs = torch.unsqueeze(img_visual, 0).to(device)
+ # forward-pass
+ pred_camera, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, smpl, mesh_sampler)
+
+ # obtain 3d joints from full mesh
+ pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices)
+
+ pred_3d_pelvis = pred_3d_joints_from_smpl[:,cfg.H36M_J17_NAME.index('Pelvis'),:]
+ pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,cfg.H36M_J17_TO_J14,:]
+ pred_3d_joints_from_smpl = pred_3d_joints_from_smpl - pred_3d_pelvis[:, None, :]
+ pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :]
+
+ # save attantion
+ att_max_value = att[-1]
+ att_cpu = np.asarray(att_max_value.cpu().detach())
+ att_all.append(att_cpu)
+
+ # obtain 3d joints, which are regressed from the full mesh
+ pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices)
+ pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,cfg.H36M_J17_TO_J14,:]
+ # obtain 2d joints, which are projected from 3d joints of smpl mesh
+ pred_2d_joints_from_smpl = orthographic_projection(pred_3d_joints_from_smpl, pred_camera)
+ pred_2d_431_vertices_from_smpl = orthographic_projection(pred_vertices_sub2, pred_camera)
+ visual_imgs_output = visualize_mesh( renderer, batch_visual_imgs[0],
+ pred_vertices[0].detach(),
+ pred_camera.detach())
+ # visual_imgs_output = visualize_mesh_and_attention( renderer, batch_visual_imgs[0],
+ # pred_vertices[0].detach(),
+ # pred_vertices_sub2[0].detach(),
+ # pred_2d_431_vertices_from_smpl[0].detach(),
+ # pred_2d_joints_from_smpl[0].detach(),
+ # pred_camera.detach(),
+ # att[-1][0].detach())
+
+ visual_imgs = visual_imgs_output.transpose(1,2,0)
+ visual_imgs = np.asarray(visual_imgs)
+
+ temp_fname = image_file[:-4] + '_graphormer_pred.jpg'
+ print('save to ', temp_fname)
+ cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255))
+
+ return
+
+def visualize_mesh( renderer, images,
+ pred_vertices_full,
+ pred_camera):
+ img = images.cpu().numpy().transpose(1,2,0)
+ # Get predict vertices for the particular example
+ vertices_full = pred_vertices_full.cpu().numpy()
+ cam = pred_camera.cpu().numpy()
+ # Visualize only mesh reconstruction
+ rend_img = visualize_reconstruction_no_text(img, 224, vertices_full, cam, renderer, color='light_blue')
+ rend_img = rend_img.transpose(2,0,1)
+ return rend_img
+
+def visualize_mesh_and_attention( renderer, images,
+ pred_vertices_full,
+ pred_vertices,
+ pred_2d_vertices,
+ pred_2d_joints,
+ pred_camera,
+ attention):
+ img = images.cpu().numpy().transpose(1,2,0)
+ # Get predict vertices for the particular example
+ vertices_full = pred_vertices_full.cpu().numpy()
+ vertices = pred_vertices.cpu().numpy()
+ vertices_2d = pred_2d_vertices.cpu().numpy()
+ joints_2d = pred_2d_joints.cpu().numpy()
+ cam = pred_camera.cpu().numpy()
+ att = attention.cpu().numpy()
+ # Visualize reconstruction and attention
+ rend_img = visualize_reconstruction_and_att_local(img, 224, vertices_full, vertices, vertices_2d, cam, renderer, joints_2d, att, color='light_blue')
+ rend_img = rend_img.transpose(2,0,1)
+ return rend_img
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ #########################################################
+ # Data related arguments
+ #########################################################
+ parser.add_argument("--num_workers", default=4, type=int,
+ help="Workers in dataloader.")
+ parser.add_argument("--img_scale_factor", default=1, type=int,
+ help="adjust image resolution.")
+ parser.add_argument("--image_file_or_path", default='./samples/human-body', type=str,
+ help="test data")
+ #########################################################
+ # Loading/saving checkpoints
+ #########################################################
+ parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False,
+ help="Path to pre-trained transformer model or model type.")
+ parser.add_argument("--resume_checkpoint", default=None, type=str, required=False,
+ help="Path to specific checkpoint for resume training.")
+ parser.add_argument("--output_dir", default='output/', type=str, required=False,
+ help="The output directory to save checkpoint and test results.")
+ parser.add_argument("--config_name", default="", type=str,
+ help="Pretrained config name or path if not the same as model_name.")
+ #########################################################
+ # Model architectures
+ #########################################################
+ parser.add_argument('-a', '--arch', default='hrnet-w64',
+ help='CNN backbone architecture: hrnet-w64, hrnet, resnet50')
+ parser.add_argument("--num_hidden_layers", default=4, type=int, required=False,
+ help="Update model config if given")
+ parser.add_argument("--hidden_size", default=-1, type=int, required=False,
+ help="Update model config if given")
+ parser.add_argument("--num_attention_heads", default=4, type=int, required=False,
+ help="Update model config if given. Note that the division of "
+ "hidden_size / num_attention_heads should be in integer.")
+ parser.add_argument("--intermediate_size", default=-1, type=int, required=False,
+ help="Update model config if given.")
+ parser.add_argument("--input_feat_dim", default='2051,512,128', type=str,
+ help="The Image Feature Dimension.")
+ parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str,
+ help="The Image Feature Dimension.")
+ parser.add_argument("--which_gcn", default='0,0,1', type=str,
+ help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. Default: only Encoder3 has graph conv")
+ parser.add_argument("--mesh_type", default='body', type=str, help="body or hand")
+ parser.add_argument("--interm_size_scale", default=2, type=int)
+ #########################################################
+ # Others
+ #########################################################
+ parser.add_argument("--run_eval_only", default=True, action='store_true',)
+ parser.add_argument("--device", type=str, default='cuda',
+ help="cuda or cpu")
+ parser.add_argument('--seed', type=int, default=88,
+ help="random seed for initialization.")
+
+ args = parser.parse_args()
+ return args
+
+
+def main(args):
+ global logger
+ # Setup CUDA, GPU & distributed training
+ args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
+ os.environ['OMP_NUM_THREADS'] = str(args.num_workers)
+ print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS']))
+
+ args.distributed = args.num_gpus > 1
+ args.device = torch.device(args.device)
+
+ mkdir(args.output_dir)
+ logger = setup_logger("Graphormer", args.output_dir, get_rank())
+ set_seed(args.seed, args.num_gpus)
+ logger.info("Using {} GPUs".format(args.num_gpus))
+
+ # Mesh and SMPL utils
+ smpl = SMPL().to(args.device)
+ mesh_sampler = Mesh()
+
+ # Renderer for visualization
+ renderer = Renderer(faces=smpl.faces.cpu().numpy())
+
+ # Load model
+ trans_encoder = []
+
+ input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')]
+ hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')]
+ output_feat_dim = input_feat_dim[1:] + [3]
+
+ # which encoder block to have graph convs
+ which_blk_graph = [int(item) for item in args.which_gcn.split(',')]
+
+ if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint:
+ # if only run eval, load checkpoint
+ logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint))
+ _model = torch.load(args.resume_checkpoint)
+ else:
+ # init three transformer-encoder blocks in a loop
+ for i in range(len(output_feat_dim)):
+ config_class, model_class = BertConfig, Graphormer
+ config = config_class.from_pretrained(args.config_name if args.config_name \
+ else args.model_name_or_path)
+
+ config.output_attentions = False
+ config.img_feature_dim = input_feat_dim[i]
+ config.output_feature_dim = output_feat_dim[i]
+ args.hidden_size = hidden_feat_dim[i]
+ args.intermediate_size = int(args.hidden_size*args.interm_size_scale)
+
+ if which_blk_graph[i]==1:
+ config.graph_conv = True
+ logger.info("Add Graph Conv")
+ else:
+ config.graph_conv = False
+
+ config.mesh_type = args.mesh_type
+
+ # update model structure if specified in arguments
+ update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size']
+
+ for idx, param in enumerate(update_params):
+ arg_param = getattr(args, param)
+ config_param = getattr(config, param)
+ if arg_param > 0 and arg_param != config_param:
+ logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param))
+ setattr(config, param, arg_param)
+
+ # init a transformer encoder and append it to a list
+ assert config.hidden_size % config.num_attention_heads == 0
+ model = model_class(config=config)
+ logger.info("Init model from scratch.")
+ trans_encoder.append(model)
+
+ # init ImageNet pre-trained backbone model
+ if args.arch=='hrnet':
+ hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
+ hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth'
+ hrnet_update_config(hrnet_config, hrnet_yaml)
+ backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
+ logger.info('=> loading hrnet-v2-w40 model')
+ elif args.arch=='hrnet-w64':
+ hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
+ hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth'
+ hrnet_update_config(hrnet_config, hrnet_yaml)
+ backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
+ logger.info('=> loading hrnet-v2-w64 model')
+ else:
+ print("=> using pre-trained model '{}'".format(args.arch))
+ backbone = models.__dict__[args.arch](pretrained=True)
+ # remove the last fc layer
+ backbone = torch.nn.Sequential(*list(backbone.children())[:-2])
+
+
+ trans_encoder = torch.nn.Sequential(*trans_encoder)
+ total_params = sum(p.numel() for p in trans_encoder.parameters())
+ logger.info('Graphormer encoders total parameters: {}'.format(total_params))
+ backbone_total_params = sum(p.numel() for p in backbone.parameters())
+ logger.info('Backbone total parameters: {}'.format(backbone_total_params))
+
+ # build end-to-end Graphormer network (CNN backbone + multi-layer graphormer encoder)
+ _model = Graphormer_Network(args, config, backbone, trans_encoder, mesh_sampler)
+
+ if args.resume_checkpoint!=None and args.resume_checkpoint!='None':
+ # for fine-tuning or resume training or inference, load weights from checkpoint
+ logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint))
+ # workaround approach to load sparse tensor in graph conv.
+ states = torch.load(args.resume_checkpoint)
+ # states = checkpoint_loaded.state_dict()
+ for k, v in states.items():
+ states[k] = v.cpu()
+ # del checkpoint_loaded
+ _model.load_state_dict(states, strict=False)
+ del states
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # update configs to enable attention outputs
+ setattr(_model.trans_encoder[-1].config,'output_attentions', True)
+ setattr(_model.trans_encoder[-1].config,'output_hidden_states', True)
+ _model.trans_encoder[-1].bert.encoder.output_attentions = True
+ _model.trans_encoder[-1].bert.encoder.output_hidden_states = True
+ for iter_layer in range(4):
+ _model.trans_encoder[-1].bert.encoder.layer[iter_layer].attention.self.output_attentions = True
+ for inter_block in range(3):
+ setattr(_model.trans_encoder[-1].config,'device', args.device)
+
+ _model.to(args.device)
+ logger.info("Run inference")
+
+ image_list = []
+ if not args.image_file_or_path:
+ raise ValueError("image_file_or_path not specified")
+ if op.isfile(args.image_file_or_path):
+ image_list = [args.image_file_or_path]
+ elif op.isdir(args.image_file_or_path):
+ # should be a path with images only
+ for filename in os.listdir(args.image_file_or_path):
+ if filename.endswith(".png") or filename.endswith(".jpg") and 'pred' not in filename:
+ image_list.append(args.image_file_or_path+'/'+filename)
+ else:
+ raise ValueError("Cannot find images at {}".format(args.image_file_or_path))
+
+ run_inference(args, image_list, _model, smpl, renderer, mesh_sampler)
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/src/custom_mesh_graphormer/tools/run_gphmer_handmesh.py b/src/custom_mesh_graphormer/tools/run_gphmer_handmesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..5971f7bb8f457f7f59943b1b98e3c4645f639398
--- /dev/null
+++ b/src/custom_mesh_graphormer/tools/run_gphmer_handmesh.py
@@ -0,0 +1,713 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+Training and evaluation codes for
+3D hand mesh reconstruction from an image
+"""
+
+from __future__ import absolute_import, division, print_function
+import argparse
+import os
+import os.path as op
+import code
+import json
+import time
+import datetime
+import torch
+import torchvision.models as models
+from torchvision.utils import make_grid
+import gc
+import numpy as np
+import cv2
+from custom_mesh_graphormer.modeling.bert import BertConfig, Graphormer
+from custom_mesh_graphormer.modeling.bert import Graphormer_Hand_Network as Graphormer_Network
+from custom_mesh_graphormer.modeling._mano import MANO, Mesh
+from custom_mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat
+from custom_mesh_graphormer.modeling.hrnet.config import config as hrnet_config
+from custom_mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
+import custom_mesh_graphormer.modeling.data.config as cfg
+from custom_mesh_graphormer.datasets.build import make_hand_data_loader
+
+from custom_mesh_graphormer.utils.logger import setup_logger
+from custom_mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather
+from custom_mesh_graphormer.utils.miscellaneous import mkdir, set_seed
+from custom_mesh_graphormer.utils.metric_logger import AverageMeter
+from custom_mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction, visualize_reconstruction_test, visualize_reconstruction_no_text
+from custom_mesh_graphormer.utils.metric_pampjpe import reconstruction_error
+from custom_mesh_graphormer.utils.geometric_layers import orthographic_projection
+
+from comfy.model_management import get_torch_device
+device = get_torch_device()
+
+from azureml.core.run import Run
+aml_run = Run.get_context()
+
+def save_checkpoint(model, args, epoch, iteration, num_trial=10):
+ checkpoint_dir = op.join(args.output_dir, 'checkpoint-{}-{}'.format(
+ epoch, iteration))
+ if not is_main_process():
+ return checkpoint_dir
+ mkdir(checkpoint_dir)
+ model_to_save = model.module if hasattr(model, 'module') else model
+ for i in range(num_trial):
+ try:
+ torch.save(model_to_save, op.join(checkpoint_dir, 'model.bin'))
+ torch.save(model_to_save.state_dict(), op.join(checkpoint_dir, 'state_dict.bin'))
+ torch.save(args, op.join(checkpoint_dir, 'training_args.bin'))
+ logger.info("Save checkpoint to {}".format(checkpoint_dir))
+ break
+ except:
+ pass
+ else:
+ logger.info("Failed to save checkpoint after {} trails.".format(num_trial))
+ return checkpoint_dir
+
+def adjust_learning_rate(optimizer, epoch, args):
+ """
+ Sets the learning rate to the initial LR decayed by x every y epochs
+ x = 0.1, y = args.num_train_epochs/2.0 = 100
+ """
+ lr = args.lr * (0.1 ** (epoch // (args.num_train_epochs/2.0) ))
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = lr
+
+def keypoint_2d_loss(criterion_keypoints, pred_keypoints_2d, gt_keypoints_2d, has_pose_2d):
+ """
+ Compute 2D reprojection loss if 2D keypoint annotations are available.
+ The confidence is binary and indicates whether the keypoints exist or not.
+ """
+ conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
+ loss = (conf * criterion_keypoints(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).mean()
+ return loss
+
+def keypoint_3d_loss(criterion_keypoints, pred_keypoints_3d, gt_keypoints_3d, has_pose_3d):
+ """
+ Compute 3D keypoint loss if 3D keypoint annotations are available.
+ """
+ conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
+ gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone()
+ gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1]
+ conf = conf[has_pose_3d == 1]
+ pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1]
+ if len(gt_keypoints_3d) > 0:
+ gt_root = gt_keypoints_3d[:, 0,:]
+ gt_keypoints_3d = gt_keypoints_3d - gt_root[:, None, :]
+ pred_root = pred_keypoints_3d[:, 0,:]
+ pred_keypoints_3d = pred_keypoints_3d - pred_root[:, None, :]
+ return (conf * criterion_keypoints(pred_keypoints_3d, gt_keypoints_3d)).mean()
+ else:
+ return torch.FloatTensor(1).fill_(0.).to(device)
+
+def vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl):
+ """
+ Compute per-vertex loss if vertex annotations are available.
+ """
+ pred_vertices_with_shape = pred_vertices[has_smpl == 1]
+ gt_vertices_with_shape = gt_vertices[has_smpl == 1]
+ if len(gt_vertices_with_shape) > 0:
+ return criterion_vertices(pred_vertices_with_shape, gt_vertices_with_shape)
+ else:
+ return torch.FloatTensor(1).fill_(0.).to(device)
+
+
+def run(args, train_dataloader, Graphormer_model, mano_model, renderer, mesh_sampler):
+
+ max_iter = len(train_dataloader)
+ iters_per_epoch = max_iter // args.num_train_epochs
+
+ optimizer = torch.optim.Adam(params=list(Graphormer_model.parameters()),
+ lr=args.lr,
+ betas=(0.9, 0.999),
+ weight_decay=0)
+
+ # define loss function (criterion) and optimizer
+ criterion_2d_keypoints = torch.nn.MSELoss(reduction='none').to(device)
+ criterion_keypoints = torch.nn.MSELoss(reduction='none').to(device)
+ criterion_vertices = torch.nn.L1Loss().to(device)
+
+ if args.distributed:
+ Graphormer_model = torch.nn.parallel.DistributedDataParallel(
+ Graphormer_model, device_ids=[args.local_rank],
+ output_device=args.local_rank,
+ find_unused_parameters=True,
+ )
+
+ start_training_time = time.time()
+ end = time.time()
+ Graphormer_model.train()
+ batch_time = AverageMeter()
+ data_time = AverageMeter()
+ log_losses = AverageMeter()
+ log_loss_2djoints = AverageMeter()
+ log_loss_3djoints = AverageMeter()
+ log_loss_vertices = AverageMeter()
+
+ for iteration, (img_keys, images, annotations) in enumerate(train_dataloader):
+
+ Graphormer_model.train()
+ iteration += 1
+ epoch = iteration // iters_per_epoch
+ batch_size = images.size(0)
+ adjust_learning_rate(optimizer, epoch, args)
+ data_time.update(time.time() - end)
+
+ images = images.to(device)
+ gt_2d_joints = annotations['joints_2d'].to(device)
+ gt_pose = annotations['pose'].to(device)
+ gt_betas = annotations['betas'].to(device)
+ has_mesh = annotations['has_smpl'].to(device)
+ has_3d_joints = has_mesh
+ has_2d_joints = has_mesh
+ mjm_mask = annotations['mjm_mask'].to(device)
+ mvm_mask = annotations['mvm_mask'].to(device)
+
+ # generate mesh
+ gt_vertices, gt_3d_joints = mano_model.layer(gt_pose, gt_betas)
+ gt_vertices = gt_vertices/1000.0
+ gt_3d_joints = gt_3d_joints/1000.0
+
+ gt_vertices_sub = mesh_sampler.downsample(gt_vertices)
+ # normalize gt based on hand's wrist
+ gt_3d_root = gt_3d_joints[:,cfg.J_NAME.index('Wrist'),:]
+ gt_vertices = gt_vertices - gt_3d_root[:, None, :]
+ gt_vertices_sub = gt_vertices_sub - gt_3d_root[:, None, :]
+ gt_3d_joints = gt_3d_joints - gt_3d_root[:, None, :]
+ gt_3d_joints_with_tag = torch.ones((batch_size,gt_3d_joints.shape[1],4)).to(device)
+ gt_3d_joints_with_tag[:,:,:3] = gt_3d_joints
+
+ # prepare masks for mask vertex/joint modeling
+ mjm_mask_ = mjm_mask.expand(-1,-1,2051)
+ mvm_mask_ = mvm_mask.expand(-1,-1,2051)
+ meta_masks = torch.cat([mjm_mask_, mvm_mask_], dim=1)
+
+ # forward-pass
+ pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices = Graphormer_model(images, mano_model, mesh_sampler, meta_masks=meta_masks, is_train=True)
+
+ # obtain 3d joints, which are regressed from the full mesh
+ pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices)
+
+ # obtain 2d joints, which are projected from 3d joints of smpl mesh
+ pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous())
+ pred_2d_joints = orthographic_projection(pred_3d_joints.contiguous(), pred_camera.contiguous())
+
+ # compute 3d joint loss (where the joints are directly output from transformer)
+ loss_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints, gt_3d_joints_with_tag, has_3d_joints)
+
+ # compute 3d vertex loss
+ loss_vertices = ( args.vloss_w_sub * vertices_loss(criterion_vertices, pred_vertices_sub, gt_vertices_sub, has_mesh) + \
+ args.vloss_w_full * vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_mesh) )
+
+ # compute 3d joint loss (where the joints are regressed from full mesh)
+ loss_reg_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints_from_mesh, gt_3d_joints_with_tag, has_3d_joints)
+ # compute 2d joint loss
+ loss_2d_joints = keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints, gt_2d_joints, has_2d_joints) + \
+ keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints_from_mesh, gt_2d_joints, has_2d_joints)
+
+ loss_3d_joints = loss_3d_joints + loss_reg_3d_joints
+
+ # we empirically use hyperparameters to balance difference losses
+ loss = args.joints_loss_weight*loss_3d_joints + \
+ args.vertices_loss_weight*loss_vertices + args.vertices_loss_weight*loss_2d_joints
+
+ # update logs
+ log_loss_2djoints.update(loss_2d_joints.item(), batch_size)
+ log_loss_3djoints.update(loss_3d_joints.item(), batch_size)
+ log_loss_vertices.update(loss_vertices.item(), batch_size)
+ log_losses.update(loss.item(), batch_size)
+
+ # back prop
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ batch_time.update(time.time() - end)
+ end = time.time()
+
+ if iteration % args.logging_steps == 0 or iteration == max_iter:
+ eta_seconds = batch_time.avg * (max_iter - iteration)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ logger.info(
+ ' '.join(
+ ['eta: {eta}', 'epoch: {ep}', 'iter: {iter}', 'max mem : {memory:.0f}',]
+ ).format(eta=eta_string, ep=epoch, iter=iteration,
+ memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
+ + ' loss: {:.4f}, 2d joint loss: {:.4f}, 3d joint loss: {:.4f}, vertex loss: {:.4f}, compute: {:.4f}, data: {:.4f}, lr: {:.6f}'.format(
+ log_losses.avg, log_loss_2djoints.avg, log_loss_3djoints.avg, log_loss_vertices.avg, batch_time.avg, data_time.avg,
+ optimizer.param_groups[0]['lr'])
+ )
+
+ aml_run.log(name='Loss', value=float(log_losses.avg))
+ aml_run.log(name='3d joint Loss', value=float(log_loss_3djoints.avg))
+ aml_run.log(name='2d joint Loss', value=float(log_loss_2djoints.avg))
+ aml_run.log(name='vertex Loss', value=float(log_loss_vertices.avg))
+
+ visual_imgs = visualize_mesh( renderer,
+ annotations['ori_img'].detach(),
+ annotations['joints_2d'].detach(),
+ pred_vertices.detach(),
+ pred_camera.detach(),
+ pred_2d_joints_from_mesh.detach())
+ visual_imgs = visual_imgs.transpose(0,1)
+ visual_imgs = visual_imgs.transpose(1,2)
+ visual_imgs = np.asarray(visual_imgs)
+
+ if is_main_process()==True:
+ stamp = str(epoch) + '_' + str(iteration)
+ temp_fname = args.output_dir + 'visual_' + stamp + '.jpg'
+ cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255))
+ aml_run.log_image(name='visual results', path=temp_fname)
+
+ if iteration % iters_per_epoch == 0:
+ if epoch%10==0:
+ checkpoint_dir = save_checkpoint(Graphormer_model, args, epoch, iteration)
+
+ total_training_time = time.time() - start_training_time
+ total_time_str = str(datetime.timedelta(seconds=total_training_time))
+ logger.info('Total training time: {} ({:.4f} s / iter)'.format(
+ total_time_str, total_training_time / max_iter)
+ )
+ checkpoint_dir = save_checkpoint(Graphormer_model, args, epoch, iteration)
+
+def run_eval_and_save(args, split, val_dataloader, Graphormer_model, mano_model, renderer, mesh_sampler):
+
+ criterion_keypoints = torch.nn.MSELoss(reduction='none').to(device)
+ criterion_vertices = torch.nn.L1Loss().to(device)
+
+ if args.distributed:
+ Graphormer_model = torch.nn.parallel.DistributedDataParallel(
+ Graphormer_model, device_ids=[args.local_rank],
+ output_device=args.local_rank,
+ find_unused_parameters=True,
+ )
+ Graphormer_model.eval()
+
+ if args.aml_eval==True:
+ run_aml_inference_hand_mesh(args, val_dataloader,
+ Graphormer_model,
+ criterion_keypoints,
+ criterion_vertices,
+ 0,
+ mano_model, mesh_sampler,
+ renderer, split)
+ else:
+ run_inference_hand_mesh(args, val_dataloader,
+ Graphormer_model,
+ criterion_keypoints,
+ criterion_vertices,
+ 0,
+ mano_model, mesh_sampler,
+ renderer, split)
+ checkpoint_dir = save_checkpoint(Graphormer_model, args, 0, 0)
+ return
+
+def run_aml_inference_hand_mesh(args, val_loader, Graphormer_model, criterion, criterion_vertices, epoch, mano_model, mesh_sampler, renderer, split):
+ # switch to evaluate mode
+ Graphormer_model.eval()
+ fname_output_save = []
+ mesh_output_save = []
+ joint_output_save = []
+ world_size = get_world_size()
+ with torch.no_grad():
+ for i, (img_keys, images, annotations) in enumerate(val_loader):
+ batch_size = images.size(0)
+ # compute output
+ images = images.to(device)
+
+ # forward-pass
+ pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices = Graphormer_model(images, mano_model, mesh_sampler)
+ # obtain 3d joints from full mesh
+ pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices)
+
+ for j in range(batch_size):
+ fname_output_save.append(img_keys[j])
+ pred_vertices_list = pred_vertices[j].tolist()
+ mesh_output_save.append(pred_vertices_list)
+ pred_3d_joints_from_mesh_list = pred_3d_joints_from_mesh[j].tolist()
+ joint_output_save.append(pred_3d_joints_from_mesh_list)
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ print('save results to pred.json')
+ output_json_file = 'pred.json'
+ print('save results to ', output_json_file)
+ with open(output_json_file, 'w') as f:
+ json.dump([joint_output_save, mesh_output_save], f)
+
+ azure_ckpt_name = '200' # args.resume_checkpoint.split('/')[-2].split('-')[1]
+ inference_setting = 'sc%02d_rot%s'%(int(args.sc*10),str(int(args.rot)))
+ output_zip_file = args.output_dir + 'ckpt' + azure_ckpt_name + '-' + inference_setting +'-pred.zip'
+
+ resolved_submit_cmd = 'zip ' + output_zip_file + ' ' + output_json_file
+ print(resolved_submit_cmd)
+ os.system(resolved_submit_cmd)
+ resolved_submit_cmd = 'rm %s'%(output_json_file)
+ print(resolved_submit_cmd)
+ os.system(resolved_submit_cmd)
+ if world_size > 1:
+ torch.distributed.barrier()
+
+ return
+
+def run_inference_hand_mesh(args, val_loader, Graphormer_model, criterion, criterion_vertices, epoch, mano_model, mesh_sampler, renderer, split):
+ # switch to evaluate mode
+ Graphormer_model.eval()
+ fname_output_save = []
+ mesh_output_save = []
+ joint_output_save = []
+ with torch.no_grad():
+ for i, (img_keys, images, annotations) in enumerate(val_loader):
+ batch_size = images.size(0)
+ # compute output
+ images = images.to(device)
+
+ # forward-pass
+ pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices = Graphormer_model(images, mano_model, mesh_sampler)
+
+ # obtain 3d joints from full mesh
+ pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices)
+ pred_3d_pelvis = pred_3d_joints_from_mesh[:,cfg.J_NAME.index('Wrist'),:]
+ pred_3d_joints_from_mesh = pred_3d_joints_from_mesh - pred_3d_pelvis[:, None, :]
+ pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :]
+
+ for j in range(batch_size):
+ fname_output_save.append(img_keys[j])
+ pred_vertices_list = pred_vertices[j].tolist()
+ mesh_output_save.append(pred_vertices_list)
+ pred_3d_joints_from_mesh_list = pred_3d_joints_from_mesh[j].tolist()
+ joint_output_save.append(pred_3d_joints_from_mesh_list)
+
+ if i%20==0:
+ # obtain 3d joints, which are regressed from the full mesh
+ pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices)
+ # obtain 2d joints, which are projected from 3d joints of mesh
+ pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous())
+ visual_imgs = visualize_mesh( renderer,
+ annotations['ori_img'].detach(),
+ annotations['joints_2d'].detach(),
+ pred_vertices.detach(),
+ pred_camera.detach(),
+ pred_2d_joints_from_mesh.detach())
+
+ visual_imgs = visual_imgs.transpose(0,1)
+ visual_imgs = visual_imgs.transpose(1,2)
+ visual_imgs = np.asarray(visual_imgs)
+
+ inference_setting = 'sc%02d_rot%s'%(int(args.sc*10),str(int(args.rot)))
+ temp_fname = args.output_dir + args.resume_checkpoint[0:-9] + 'freihand_results_'+inference_setting+'_batch'+str(i)+'.jpg'
+ cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255))
+
+ print('save results to pred.json')
+ with open('pred.json', 'w') as f:
+ json.dump([joint_output_save, mesh_output_save], f)
+
+ run_exp_name = args.resume_checkpoint.split('/')[-3]
+ run_ckpt_name = args.resume_checkpoint.split('/')[-2].split('-')[1]
+ inference_setting = 'sc%02d_rot%s'%(int(args.sc*10),str(int(args.rot)))
+ resolved_submit_cmd = 'zip ' + args.output_dir + run_exp_name + '-ckpt'+ run_ckpt_name + '-' + inference_setting +'-pred.zip ' + 'pred.json'
+ print(resolved_submit_cmd)
+ os.system(resolved_submit_cmd)
+ resolved_submit_cmd = 'rm pred.json'
+ print(resolved_submit_cmd)
+ os.system(resolved_submit_cmd)
+ return
+
+def visualize_mesh( renderer,
+ images,
+ gt_keypoints_2d,
+ pred_vertices,
+ pred_camera,
+ pred_keypoints_2d):
+ """Tensorboard logging."""
+ gt_keypoints_2d = gt_keypoints_2d.cpu().numpy()
+ to_lsp = list(range(21))
+ rend_imgs = []
+ batch_size = pred_vertices.shape[0]
+ # Do visualization for the first 6 images of the batch
+ for i in range(min(batch_size, 10)):
+ img = images[i].cpu().numpy().transpose(1,2,0)
+ # Get LSP keypoints from the full list of keypoints
+ gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp]
+ pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp]
+ # Get predict vertices for the particular example
+ vertices = pred_vertices[i].cpu().numpy()
+ cam = pred_camera[i].cpu().numpy()
+ # Visualize reconstruction and detected pose
+ rend_img = visualize_reconstruction(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer)
+ rend_img = rend_img.transpose(2,0,1)
+ rend_imgs.append(torch.from_numpy(rend_img))
+ rend_imgs = make_grid(rend_imgs, nrow=1)
+ return rend_imgs
+
+def visualize_mesh_test( renderer,
+ images,
+ gt_keypoints_2d,
+ pred_vertices,
+ pred_camera,
+ pred_keypoints_2d,
+ PAmPJPE):
+ """Tensorboard logging."""
+ gt_keypoints_2d = gt_keypoints_2d.cpu().numpy()
+ to_lsp = list(range(21))
+ rend_imgs = []
+ batch_size = pred_vertices.shape[0]
+ # Do visualization for the first 6 images of the batch
+ for i in range(min(batch_size, 10)):
+ img = images[i].cpu().numpy().transpose(1,2,0)
+ # Get LSP keypoints from the full list of keypoints
+ gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp]
+ pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp]
+ # Get predict vertices for the particular example
+ vertices = pred_vertices[i].cpu().numpy()
+ cam = pred_camera[i].cpu().numpy()
+ score = PAmPJPE[i]
+ # Visualize reconstruction and detected pose
+ rend_img = visualize_reconstruction_test(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer, score)
+ rend_img = rend_img.transpose(2,0,1)
+ rend_imgs.append(torch.from_numpy(rend_img))
+ rend_imgs = make_grid(rend_imgs, nrow=1)
+ return rend_imgs
+
+def visualize_mesh_no_text( renderer,
+ images,
+ pred_vertices,
+ pred_camera):
+ """Tensorboard logging."""
+ rend_imgs = []
+ batch_size = pred_vertices.shape[0]
+ # Do visualization for the first 6 images of the batch
+ for i in range(min(batch_size, 1)):
+ img = images[i].cpu().numpy().transpose(1,2,0)
+ # Get predict vertices for the particular example
+ vertices = pred_vertices[i].cpu().numpy()
+ cam = pred_camera[i].cpu().numpy()
+ # Visualize reconstruction only
+ rend_img = visualize_reconstruction_no_text(img, 224, vertices, cam, renderer, color='hand')
+ rend_img = rend_img.transpose(2,0,1)
+ rend_imgs.append(torch.from_numpy(rend_img))
+ rend_imgs = make_grid(rend_imgs, nrow=1)
+ return rend_imgs
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ #########################################################
+ # Data related arguments
+ #########################################################
+ parser.add_argument("--data_dir", default='datasets', type=str, required=False,
+ help="Directory with all datasets, each in one subfolder")
+ parser.add_argument("--train_yaml", default='imagenet2012/train.yaml', type=str, required=False,
+ help="Yaml file with all data for training.")
+ parser.add_argument("--val_yaml", default='imagenet2012/test.yaml', type=str, required=False,
+ help="Yaml file with all data for validation.")
+ parser.add_argument("--num_workers", default=4, type=int,
+ help="Workers in dataloader.")
+ parser.add_argument("--img_scale_factor", default=1, type=int,
+ help="adjust image resolution.")
+ #########################################################
+ # Loading/saving checkpoints
+ #########################################################
+ parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False,
+ help="Path to pre-trained transformer model or model type.")
+ parser.add_argument("--resume_checkpoint", default=None, type=str, required=False,
+ help="Path to specific checkpoint for resume training.")
+ parser.add_argument("--output_dir", default='output/', type=str, required=False,
+ help="The output directory to save checkpoint and test results.")
+ parser.add_argument("--config_name", default="", type=str,
+ help="Pretrained config name or path if not the same as model_name.")
+ parser.add_argument('-a', '--arch', default='hrnet-w64',
+ help='CNN backbone architecture: hrnet-w64, hrnet, resnet50')
+ #########################################################
+ # Training parameters
+ #########################################################
+ parser.add_argument("--per_gpu_train_batch_size", default=64, type=int,
+ help="Batch size per GPU/CPU for training.")
+ parser.add_argument("--per_gpu_eval_batch_size", default=64, type=int,
+ help="Batch size per GPU/CPU for evaluation.")
+ parser.add_argument('--lr', "--learning_rate", default=1e-4, type=float,
+ help="The initial lr.")
+ parser.add_argument("--num_train_epochs", default=200, type=int,
+ help="Total number of training epochs to perform.")
+ parser.add_argument("--vertices_loss_weight", default=1.0, type=float)
+ parser.add_argument("--joints_loss_weight", default=1.0, type=float)
+ parser.add_argument("--vloss_w_full", default=0.5, type=float)
+ parser.add_argument("--vloss_w_sub", default=0.5, type=float)
+ parser.add_argument("--drop_out", default=0.1, type=float,
+ help="Drop out ratio in BERT.")
+ #########################################################
+ # Model architectures
+ #########################################################
+ parser.add_argument("--num_hidden_layers", default=-1, type=int, required=False,
+ help="Update model config if given")
+ parser.add_argument("--hidden_size", default=-1, type=int, required=False,
+ help="Update model config if given")
+ parser.add_argument("--num_attention_heads", default=-1, type=int, required=False,
+ help="Update model config if given. Note that the division of "
+ "hidden_size / num_attention_heads should be in integer.")
+ parser.add_argument("--intermediate_size", default=-1, type=int, required=False,
+ help="Update model config if given.")
+ parser.add_argument("--input_feat_dim", default='2051,512,128', type=str,
+ help="The Image Feature Dimension.")
+ parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str,
+ help="The Image Feature Dimension.")
+ parser.add_argument("--which_gcn", default='0,0,1', type=str,
+ help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. Default: only Encoder3 has graph conv")
+ parser.add_argument("--mesh_type", default='hand', type=str, help="body or hand")
+
+ #########################################################
+ # Others
+ #########################################################
+ parser.add_argument("--run_eval_only", default=False, action='store_true',)
+ parser.add_argument("--multiscale_inference", default=False, action='store_true',)
+ # if enable "multiscale_inference", dataloader will apply transformations to the test image based on
+ # the rotation "rot" and scale "sc" parameters below
+ parser.add_argument("--rot", default=0, type=float)
+ parser.add_argument("--sc", default=1.0, type=float)
+ parser.add_argument("--aml_eval", default=False, action='store_true',)
+
+ parser.add_argument('--logging_steps', type=int, default=100,
+ help="Log every X steps.")
+ parser.add_argument("--device", type=str, default='cuda',
+ help="cuda or cpu")
+ parser.add_argument('--seed', type=int, default=88,
+ help="random seed for initialization.")
+ parser.add_argument("--local_rank", type=int, default=0,
+ help="For distributed training.")
+ args = parser.parse_args()
+ return args
+
+def main(args):
+ global logger
+ # Setup CUDA, GPU & distributed training
+ args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
+ os.environ['OMP_NUM_THREADS'] = str(args.num_workers)
+ print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS']))
+
+ args.distributed = args.num_gpus > 1
+ args.device = torch.device(args.device)
+ if args.distributed:
+ print("Init distributed training on local rank {}".format(args.local_rank))
+ torch.cuda.set_device(args.local_rank)
+ torch.distributed.init_process_group(
+ backend='nccl', init_method='env://'
+ )
+ synchronize()
+
+ mkdir(args.output_dir)
+ logger = setup_logger("Graphormer", args.output_dir, get_rank())
+ set_seed(args.seed, args.num_gpus)
+ logger.info("Using {} GPUs".format(args.num_gpus))
+
+ # Mesh and SMPL utils
+ mano_model = MANO().to(args.device)
+ mano_model.layer = mano_model.layer.to(device)
+ mesh_sampler = Mesh()
+
+ # Renderer for visualization
+ renderer = Renderer(faces=mano_model.face)
+
+ # Load pretrained model
+ trans_encoder = []
+
+ input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')]
+ hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')]
+ output_feat_dim = input_feat_dim[1:] + [3]
+
+ # which encoder block to have graph convs
+ which_blk_graph = [int(item) for item in args.which_gcn.split(',')]
+
+ if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint:
+ # if only run eval, load checkpoint
+ logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint))
+ _model = torch.load(args.resume_checkpoint)
+
+ else:
+ # init three transformer-encoder blocks in a loop
+ for i in range(len(output_feat_dim)):
+ config_class, model_class = BertConfig, Graphormer
+ config = config_class.from_pretrained(args.config_name if args.config_name \
+ else args.model_name_or_path)
+
+ config.output_attentions = False
+ config.hidden_dropout_prob = args.drop_out
+ config.img_feature_dim = input_feat_dim[i]
+ config.output_feature_dim = output_feat_dim[i]
+ args.hidden_size = hidden_feat_dim[i]
+ args.intermediate_size = int(args.hidden_size*2)
+
+ if which_blk_graph[i]==1:
+ config.graph_conv = True
+ logger.info("Add Graph Conv")
+ else:
+ config.graph_conv = False
+
+ config.mesh_type = args.mesh_type
+
+ # update model structure if specified in arguments
+ update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size']
+ for idx, param in enumerate(update_params):
+ arg_param = getattr(args, param)
+ config_param = getattr(config, param)
+ if arg_param > 0 and arg_param != config_param:
+ logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param))
+ setattr(config, param, arg_param)
+
+ # init a transformer encoder and append it to a list
+ assert config.hidden_size % config.num_attention_heads == 0
+ model = model_class(config=config)
+ logger.info("Init model from scratch.")
+ trans_encoder.append(model)
+
+ # create backbone model
+ if args.arch=='hrnet':
+ hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
+ hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth'
+ hrnet_update_config(hrnet_config, hrnet_yaml)
+ backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
+ logger.info('=> loading hrnet-v2-w40 model')
+ elif args.arch=='hrnet-w64':
+ hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
+ hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth'
+ hrnet_update_config(hrnet_config, hrnet_yaml)
+ backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
+ logger.info('=> loading hrnet-v2-w64 model')
+ else:
+ print("=> using pre-trained model '{}'".format(args.arch))
+ backbone = models.__dict__[args.arch](pretrained=True)
+ # remove the last fc layer
+ backbone = torch.nn.Sequential(*list(backbone.children())[:-1])
+
+ trans_encoder = torch.nn.Sequential(*trans_encoder)
+ total_params = sum(p.numel() for p in trans_encoder.parameters())
+ logger.info('Graphormer encoders total parameters: {}'.format(total_params))
+ backbone_total_params = sum(p.numel() for p in backbone.parameters())
+ logger.info('Backbone total parameters: {}'.format(backbone_total_params))
+
+ # build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder)
+ _model = Graphormer_Network(args, config, backbone, trans_encoder)
+
+ if args.resume_checkpoint!=None and args.resume_checkpoint!='None':
+ # for fine-tuning or resume training or inference, load weights from checkpoint
+ logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint))
+ # workaround approach to load sparse tensor in graph conv.
+ state_dict = torch.load(args.resume_checkpoint)
+ _model.load_state_dict(state_dict, strict=False)
+ del state_dict
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ _model.to(args.device)
+ logger.info("Training parameters %s", args)
+
+ if args.run_eval_only==True:
+ val_dataloader = make_hand_data_loader(args, args.val_yaml,
+ args.distributed, is_train=False, scale_factor=args.img_scale_factor)
+ run_eval_and_save(args, 'freihand', val_dataloader, _model, mano_model, renderer, mesh_sampler)
+
+ else:
+ train_dataloader = make_hand_data_loader(args, args.train_yaml,
+ args.distributed, is_train=True, scale_factor=args.img_scale_factor)
+ run(args, train_dataloader, _model, mano_model, renderer, mesh_sampler)
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/src/custom_mesh_graphormer/tools/run_gphmer_handmesh_inference.py b/src/custom_mesh_graphormer/tools/run_gphmer_handmesh_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b7300a77dee4cfc11160abfbc19e2acf89d26ed
--- /dev/null
+++ b/src/custom_mesh_graphormer/tools/run_gphmer_handmesh_inference.py
@@ -0,0 +1,338 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+End-to-end inference codes for
+3D hand mesh reconstruction from an image
+"""
+
+from __future__ import absolute_import, division, print_function
+import argparse
+import os
+import os.path as op
+import code
+import json
+import time
+import datetime
+import torch
+import torchvision.models as models
+from torchvision.utils import make_grid
+import gc
+import numpy as np
+import cv2
+from custom_mesh_graphormer.modeling.bert import BertConfig, Graphormer
+from custom_mesh_graphormer.modeling.bert import Graphormer_Hand_Network as Graphormer_Network
+from custom_mesh_graphormer.modeling._mano import MANO, Mesh
+from custom_mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat
+from custom_mesh_graphormer.modeling.hrnet.config import config as hrnet_config
+from custom_mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
+import custom_mesh_graphormer.modeling.data.config as cfg
+from custom_mesh_graphormer.datasets.build import make_hand_data_loader
+
+from custom_mesh_graphormer.utils.logger import setup_logger
+from custom_mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather
+from custom_mesh_graphormer.utils.miscellaneous import mkdir, set_seed
+from custom_mesh_graphormer.utils.metric_logger import AverageMeter
+from custom_mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction_and_att_local, visualize_reconstruction_no_text
+from custom_mesh_graphormer.utils.metric_pampjpe import reconstruction_error
+from custom_mesh_graphormer.utils.geometric_layers import orthographic_projection
+
+from PIL import Image
+from torchvision import transforms
+
+from comfy.model_management import get_torch_device
+device = get_torch_device()
+
+transform = transforms.Compose([
+ transforms.Resize(224),
+ transforms.CenterCrop(224),
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])])
+
+transform_visualize = transforms.Compose([
+ transforms.Resize(224),
+ transforms.CenterCrop(224),
+ transforms.ToTensor()])
+
+def run_inference(args, image_list, Graphormer_model, mano, renderer, mesh_sampler):
+# switch to evaluate mode
+ Graphormer_model.eval()
+ mano.eval()
+ with torch.no_grad():
+ for image_file in image_list:
+ if 'pred' not in image_file:
+ att_all = []
+ print(image_file)
+ img = Image.open(image_file)
+ img_tensor = transform(img)
+ img_visual = transform_visualize(img)
+
+ batch_imgs = torch.unsqueeze(img_tensor, 0).to(device)
+ batch_visual_imgs = torch.unsqueeze(img_visual, 0).to(device)
+ # forward-pass
+ pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, mano, mesh_sampler)
+ # obtain 3d joints from full mesh
+ pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices)
+ pred_3d_pelvis = pred_3d_joints_from_mesh[:,cfg.J_NAME.index('Wrist'),:]
+ pred_3d_joints_from_mesh = pred_3d_joints_from_mesh - pred_3d_pelvis[:, None, :]
+ pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :]
+
+ # save attantion
+ att_max_value = att[-1]
+ att_cpu = np.asarray(att_max_value.cpu().detach())
+ att_all.append(att_cpu)
+
+ # obtain 3d joints, which are regressed from the full mesh
+ pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices)
+ # obtain 2d joints, which are projected from 3d joints of mesh
+ pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous())
+ pred_2d_coarse_vertices_from_mesh = orthographic_projection(pred_vertices_sub.contiguous(), pred_camera.contiguous())
+
+
+ visual_imgs_output = visualize_mesh( renderer, batch_visual_imgs[0],
+ pred_vertices[0].detach(),
+ pred_camera.detach())
+ # visual_imgs_output = visualize_mesh_and_attention( renderer, batch_visual_imgs[0],
+ # pred_vertices[0].detach(),
+ # pred_vertices_sub[0].detach(),
+ # pred_2d_coarse_vertices_from_mesh[0].detach(),
+ # pred_2d_joints_from_mesh[0].detach(),
+ # pred_camera.detach(),
+ # att[-1][0].detach())
+ visual_imgs = visual_imgs_output.transpose(1,2,0)
+ visual_imgs = np.asarray(visual_imgs)
+
+ temp_fname = image_file[:-4] + '_graphormer_pred.jpg'
+ print('save to ', temp_fname)
+ cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255))
+ return
+
+def visualize_mesh( renderer, images,
+ pred_vertices_full,
+ pred_camera):
+ img = images.cpu().numpy().transpose(1,2,0)
+ # Get predict vertices for the particular example
+ vertices_full = pred_vertices_full.cpu().numpy()
+ cam = pred_camera.cpu().numpy()
+ # Visualize only mesh reconstruction
+ rend_img = visualize_reconstruction_no_text(img, 224, vertices_full, cam, renderer, color='light_blue')
+ rend_img = rend_img.transpose(2,0,1)
+ return rend_img
+
+def visualize_mesh_and_attention( renderer, images,
+ pred_vertices_full,
+ pred_vertices,
+ pred_2d_vertices,
+ pred_2d_joints,
+ pred_camera,
+ attention):
+ img = images.cpu().numpy().transpose(1,2,0)
+ # Get predict vertices for the particular example
+ vertices_full = pred_vertices_full.cpu().numpy()
+ vertices = pred_vertices.cpu().numpy()
+ vertices_2d = pred_2d_vertices.cpu().numpy()
+ joints_2d = pred_2d_joints.cpu().numpy()
+ cam = pred_camera.cpu().numpy()
+ att = attention.cpu().numpy()
+ # Visualize reconstruction and attention
+ rend_img = visualize_reconstruction_and_att_local(img, 224, vertices_full, vertices, vertices_2d, cam, renderer, joints_2d, att, color='light_blue')
+ rend_img = rend_img.transpose(2,0,1)
+ return rend_img
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ #########################################################
+ # Data related arguments
+ #########################################################
+ parser.add_argument("--num_workers", default=4, type=int,
+ help="Workers in dataloader.")
+ parser.add_argument("--img_scale_factor", default=1, type=int,
+ help="adjust image resolution.")
+ parser.add_argument("--image_file_or_path", default='./samples/hand', type=str,
+ help="test data")
+ #########################################################
+ # Loading/saving checkpoints
+ #########################################################
+ parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False,
+ help="Path to pre-trained transformer model or model type.")
+ parser.add_argument("--resume_checkpoint", default=None, type=str, required=False,
+ help="Path to specific checkpoint for resume training.")
+ parser.add_argument("--output_dir", default='output/', type=str, required=False,
+ help="The output directory to save checkpoint and test results.")
+ parser.add_argument("--config_name", default="", type=str,
+ help="Pretrained config name or path if not the same as model_name.")
+ parser.add_argument('-a', '--arch', default='hrnet-w64',
+ help='CNN backbone architecture: hrnet-w64, hrnet, resnet50')
+ #########################################################
+ # Model architectures
+ #########################################################
+ parser.add_argument("--num_hidden_layers", default=4, type=int, required=False,
+ help="Update model config if given")
+ parser.add_argument("--hidden_size", default=-1, type=int, required=False,
+ help="Update model config if given")
+ parser.add_argument("--num_attention_heads", default=4, type=int, required=False,
+ help="Update model config if given. Note that the division of "
+ "hidden_size / num_attention_heads should be in integer.")
+ parser.add_argument("--intermediate_size", default=-1, type=int, required=False,
+ help="Update model config if given.")
+ parser.add_argument("--input_feat_dim", default='2051,512,128', type=str,
+ help="The Image Feature Dimension.")
+ parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str,
+ help="The Image Feature Dimension.")
+ parser.add_argument("--which_gcn", default='0,0,1', type=str,
+ help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. Default: only Encoder3 has graph conv")
+ parser.add_argument("--mesh_type", default='hand', type=str, help="body or hand")
+
+ #########################################################
+ # Others
+ #########################################################
+ parser.add_argument("--run_eval_only", default=True, action='store_true',)
+ parser.add_argument("--device", type=str, default='cuda',
+ help="cuda or cpu")
+ parser.add_argument('--seed', type=int, default=88,
+ help="random seed for initialization.")
+ args = parser.parse_args()
+ return args
+
+def main(args):
+ global logger
+ # Setup CUDA, GPU & distributed training
+ args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
+ os.environ['OMP_NUM_THREADS'] = str(args.num_workers)
+ print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS']))
+
+ mkdir(args.output_dir)
+ logger = setup_logger("Graphormer", args.output_dir, get_rank())
+ set_seed(args.seed, args.num_gpus)
+ logger.info("Using {} GPUs".format(args.num_gpus))
+
+ # Mesh and MANO utils
+ mano_model = MANO().to(args.device)
+ mano_model.layer = mano_model.layer.to(device)
+ mesh_sampler = Mesh()
+
+ # Renderer for visualization
+ renderer = Renderer(faces=mano_model.face)
+
+ # Load pretrained model
+ trans_encoder = []
+
+ input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')]
+ hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')]
+ output_feat_dim = input_feat_dim[1:] + [3]
+
+ # which encoder block to have graph convs
+ which_blk_graph = [int(item) for item in args.which_gcn.split(',')]
+
+ if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint:
+ # if only run eval, load checkpoint
+ logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint))
+ _model = torch.load(args.resume_checkpoint)
+
+ else:
+ # init three transformer-encoder blocks in a loop
+ for i in range(len(output_feat_dim)):
+ config_class, model_class = BertConfig, Graphormer
+ config = config_class.from_pretrained(args.config_name if args.config_name \
+ else args.model_name_or_path)
+
+ config.output_attentions = False
+ config.img_feature_dim = input_feat_dim[i]
+ config.output_feature_dim = output_feat_dim[i]
+ args.hidden_size = hidden_feat_dim[i]
+ args.intermediate_size = int(args.hidden_size*2)
+
+ if which_blk_graph[i]==1:
+ config.graph_conv = True
+ logger.info("Add Graph Conv")
+ else:
+ config.graph_conv = False
+
+ config.mesh_type = args.mesh_type
+
+ # update model structure if specified in arguments
+ update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size']
+ for idx, param in enumerate(update_params):
+ arg_param = getattr(args, param)
+ config_param = getattr(config, param)
+ if arg_param > 0 and arg_param != config_param:
+ logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param))
+ setattr(config, param, arg_param)
+
+ # init a transformer encoder and append it to a list
+ assert config.hidden_size % config.num_attention_heads == 0
+ model = model_class(config=config)
+ logger.info("Init model from scratch.")
+ trans_encoder.append(model)
+
+ # create backbone model
+ if args.arch=='hrnet':
+ hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
+ hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth'
+ hrnet_update_config(hrnet_config, hrnet_yaml)
+ backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
+ logger.info('=> loading hrnet-v2-w40 model')
+ elif args.arch=='hrnet-w64':
+ hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
+ hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth'
+ hrnet_update_config(hrnet_config, hrnet_yaml)
+ backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
+ logger.info('=> loading hrnet-v2-w64 model')
+ else:
+ print("=> using pre-trained model '{}'".format(args.arch))
+ backbone = models.__dict__[args.arch](pretrained=True)
+ # remove the last fc layer
+ backbone = torch.nn.Sequential(*list(backbone.children())[:-1])
+
+ trans_encoder = torch.nn.Sequential(*trans_encoder)
+ total_params = sum(p.numel() for p in trans_encoder.parameters())
+ logger.info('Graphormer encoders total parameters: {}'.format(total_params))
+ backbone_total_params = sum(p.numel() for p in backbone.parameters())
+ logger.info('Backbone total parameters: {}'.format(backbone_total_params))
+
+ # build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder)
+ _model = Graphormer_Network(args, config, backbone, trans_encoder)
+
+ if args.resume_checkpoint!=None and args.resume_checkpoint!='None':
+ # for fine-tuning or resume training or inference, load weights from checkpoint
+ logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint))
+ # workaround approach to load sparse tensor in graph conv.
+ state_dict = torch.load(args.resume_checkpoint)
+ _model.load_state_dict(state_dict, strict=False)
+ del state_dict
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # update configs to enable attention outputs
+ setattr(_model.trans_encoder[-1].config,'output_attentions', True)
+ setattr(_model.trans_encoder[-1].config,'output_hidden_states', True)
+ _model.trans_encoder[-1].bert.encoder.output_attentions = True
+ _model.trans_encoder[-1].bert.encoder.output_hidden_states = True
+ for iter_layer in range(4):
+ _model.trans_encoder[-1].bert.encoder.layer[iter_layer].attention.self.output_attentions = True
+ for inter_block in range(3):
+ setattr(_model.trans_encoder[-1].config,'device', args.device)
+
+ _model.to(args.device)
+ logger.info("Run inference")
+
+ image_list = []
+ if not args.image_file_or_path:
+ raise ValueError("image_file_or_path not specified")
+ if op.isfile(args.image_file_or_path):
+ image_list = [args.image_file_or_path]
+ elif op.isdir(args.image_file_or_path):
+ # should be a path with images only
+ for filename in os.listdir(args.image_file_or_path):
+ if filename.endswith(".png") or filename.endswith(".jpg") and 'pred' not in filename:
+ image_list.append(args.image_file_or_path+'/'+filename)
+ else:
+ raise ValueError("Cannot find images at {}".format(args.image_file_or_path))
+
+ run_inference(args, image_list, _model, mano_model, renderer, mesh_sampler)
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/src/custom_mesh_graphormer/tools/run_hand_multiscale.py b/src/custom_mesh_graphormer/tools/run_hand_multiscale.py
new file mode 100644
index 0000000000000000000000000000000000000000..d656756b7c347d0c55bedf8b359f89b5f60ee2fb
--- /dev/null
+++ b/src/custom_mesh_graphormer/tools/run_hand_multiscale.py
@@ -0,0 +1,136 @@
+from __future__ import absolute_import, division, print_function
+
+import argparse
+import os
+import os.path as op
+import code
+import json
+import zipfile
+import torch
+import numpy as np
+from custom_mesh_graphormer.utils.metric_pampjpe import get_alignMesh
+
+
+def load_pred_json(filepath):
+ archive = zipfile.ZipFile(filepath, 'r')
+ jsondata = archive.read('pred.json')
+ reference = json.loads(jsondata.decode("utf-8"))
+ return reference[0], reference[1]
+
+
+def multiscale_fusion(output_dir):
+ s = '10'
+ filepath = output_dir+'ckpt200-sc10_rot0-pred.zip'
+ ref_joints, ref_vertices = load_pred_json(filepath)
+ ref_joints_array = np.asarray(ref_joints)
+ ref_vertices_array = np.asarray(ref_vertices)
+
+ rotations = [0.0]
+ for i in range(1,10):
+ rotations.append(i*10)
+ rotations.append(i*-10)
+
+ scale = [0.7,0.8,0.9,1.0,1.1]
+ multiscale_joints = []
+ multiscale_vertices = []
+
+ counter = 0
+ for s in scale:
+ for r in rotations:
+ setting = 'sc%02d_rot%s'%(int(s*10),str(int(r)))
+ filepath = output_dir+'ckpt200-'+setting+'-pred.zip'
+ joints, vertices = load_pred_json(filepath)
+ joints_array = np.asarray(joints)
+ vertices_array = np.asarray(vertices)
+
+ pa_joint_error, pa_joint_array, _ = get_alignMesh(joints_array, ref_joints_array, reduction=None)
+ pa_vertices_error, pa_vertices_array, _ = get_alignMesh(vertices_array, ref_vertices_array, reduction=None)
+ print('--------------------------')
+ print('scale:', s, 'rotate', r)
+ print('PAMPJPE:', 1000*np.mean(pa_joint_error))
+ print('PAMPVPE:', 1000*np.mean(pa_vertices_error))
+ multiscale_joints.append(pa_joint_array)
+ multiscale_vertices.append(pa_vertices_array)
+ counter = counter + 1
+
+ overall_joints_array = ref_joints_array.copy()
+ overall_vertices_array = ref_vertices_array.copy()
+ for i in range(counter):
+ overall_joints_array += multiscale_joints[i]
+ overall_vertices_array += multiscale_vertices[i]
+
+ overall_joints_array /= (1+counter)
+ overall_vertices_array /= (1+counter)
+ pa_joint_error, pa_joint_array, _ = get_alignMesh(overall_joints_array, ref_joints_array, reduction=None)
+ pa_vertices_error, pa_vertices_array, _ = get_alignMesh(overall_vertices_array, ref_vertices_array, reduction=None)
+ print('--------------------------')
+ print('overall:')
+ print('PAMPJPE:', 1000*np.mean(pa_joint_error))
+ print('PAMPVPE:', 1000*np.mean(pa_vertices_error))
+
+ joint_output_save = overall_joints_array.tolist()
+ mesh_output_save = overall_vertices_array.tolist()
+
+ print('save results to pred.json')
+ with open('pred.json', 'w') as f:
+ json.dump([joint_output_save, mesh_output_save], f)
+
+
+ filepath = output_dir+'ckpt200-multisc-pred.zip'
+ resolved_submit_cmd = 'zip ' + filepath + ' ' + 'pred.json'
+ print(resolved_submit_cmd)
+ os.system(resolved_submit_cmd)
+ resolved_submit_cmd = 'rm pred.json'
+ print(resolved_submit_cmd)
+ os.system(resolved_submit_cmd)
+
+
+def run_multiscale_inference(model_path, mode, output_dir):
+
+ if mode==True:
+ rotations = [0.0]
+ for i in range(1,10):
+ rotations.append(i*10)
+ rotations.append(i*-10)
+ scale = [0.7,0.8,0.9,1.0,1.1]
+ else:
+ rotations = [0.0]
+ scale = [1.0]
+
+ job_cmd = "python ./src/tools/run_gphmer_handmesh.py " \
+ "--val_yaml freihand_v3/test.yaml " \
+ "--resume_checkpoint %s " \
+ "--per_gpu_eval_batch_size 32 --run_eval_only --num_worker 2 " \
+ "--multiscale_inference " \
+ "--rot %f " \
+ "--sc %s " \
+ "--arch hrnet-w64 " \
+ "--num_hidden_layers 4 " \
+ "--num_attention_heads 4 " \
+ "--input_feat_dim 2051,512,128 " \
+ "--hidden_feat_dim 1024,256,64 " \
+ "--output_dir %s"
+
+ for s in scale:
+ for r in rotations:
+ resolved_submit_cmd = job_cmd%(model_path, r, s, output_dir)
+ print(resolved_submit_cmd)
+ os.system(resolved_submit_cmd)
+
+def main(args):
+ model_path = args.model_path
+ mode = args.multiscale_inference
+ output_dir = args.output_dir
+ run_multiscale_inference(model_path, mode, output_dir)
+ if mode==True:
+ multiscale_fusion(output_dir)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Evaluate a checkpoint in the folder")
+ parser.add_argument("--model_path")
+ parser.add_argument("--multiscale_inference", default=False, action='store_true',)
+ parser.add_argument("--output_dir", default='output/', type=str, required=False,
+ help="The output directory to save checkpoint and test results.")
+ args = parser.parse_args()
+ main(args)
diff --git a/src/custom_mesh_graphormer/utils/__init__.py b/src/custom_mesh_graphormer/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/custom_mesh_graphormer/utils/comm.py b/src/custom_mesh_graphormer/utils/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae05350639005b43ee8c96803d0d545442a24968
--- /dev/null
+++ b/src/custom_mesh_graphormer/utils/comm.py
@@ -0,0 +1,176 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+This file contains primitives for multi-gpu communication.
+This is useful when doing distributed training.
+"""
+
+import pickle
+import time
+
+import torch
+import torch.distributed as dist
+
+from comfy.model_management import get_torch_device
+device = get_torch_device()
+
+
+def get_world_size():
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def synchronize():
+ """
+ Helper function to synchronize (barrier) among all processes when
+ using distributed training
+ """
+ if not dist.is_available():
+ return
+ if not dist.is_initialized():
+ return
+ world_size = dist.get_world_size()
+ if world_size == 1:
+ return
+ dist.barrier()
+
+
+def gather_on_master(data):
+ """Same as all_gather, but gathers data on master process only, using CPU.
+ Thus, this does not work with NCCL backend unless they add CPU support.
+
+ The memory consumption of this function is ~ 3x of data size. While in
+ principal, it should be ~2x, it's not easy to force Python to release
+ memory immediately and thus, peak memory usage could be up to 3x.
+ """
+ world_size = get_world_size()
+ if world_size == 1:
+ return [data]
+
+ # serialized to a Tensor
+ buffer = pickle.dumps(data)
+ # trying to optimize memory, but in fact, it's not guaranteed to be released
+ del data
+ storage = torch.ByteStorage.from_buffer(buffer)
+ del buffer
+ tensor = torch.ByteTensor(storage)
+
+ # obtain Tensor size of each rank
+ local_size = torch.LongTensor([tensor.numel()])
+ size_list = [torch.LongTensor([0]) for _ in range(world_size)]
+ dist.all_gather(size_list, local_size)
+ size_list = [int(size.item()) for size in size_list]
+ max_size = max(size_list)
+
+ if local_size != max_size:
+ padding = torch.ByteTensor(size=(max_size - local_size,))
+ tensor = torch.cat((tensor, padding), dim=0)
+ del padding
+
+ if is_main_process():
+ tensor_list = []
+ for _ in size_list:
+ tensor_list.append(torch.ByteTensor(size=(max_size,)))
+ dist.gather(tensor, gather_list=tensor_list, dst=0)
+ del tensor
+ else:
+ dist.gather(tensor, gather_list=[], dst=0)
+ del tensor
+ return
+
+ data_list = []
+ for tensor in tensor_list:
+ buffer = tensor.cpu().numpy().tobytes()
+ del tensor
+ data_list.append(pickle.loads(buffer))
+ del buffer
+
+ return data_list
+
+
+def all_gather(data):
+ """
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
+ Args:
+ data: any picklable object
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+ world_size = get_world_size()
+ if world_size == 1:
+ return [data]
+
+ # serialized to a Tensor
+ buffer = pickle.dumps(data)
+ storage = torch.ByteStorage.from_buffer(buffer)
+ tensor = torch.ByteTensor(storage).to(device)
+
+ # obtain Tensor size of each rank
+ local_size = torch.LongTensor([tensor.numel()]).to(device)
+ size_list = [torch.LongTensor([0]).to(device) for _ in range(world_size)]
+ dist.all_gather(size_list, local_size)
+ size_list = [int(size.item()) for size in size_list]
+ max_size = max(size_list)
+
+ # receiving Tensor from all ranks
+ # we pad the tensor because torch all_gather does not support
+ # gathering tensors of different shapes
+ tensor_list = []
+ for _ in size_list:
+ tensor_list.append(torch.ByteTensor(size=(max_size,)).to(device))
+ if local_size != max_size:
+ padding = torch.ByteTensor(size=(max_size - local_size,)).to(device)
+ tensor = torch.cat((tensor, padding), dim=0)
+ dist.all_gather(tensor_list, tensor)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+
+ return data_list
+
+
+def reduce_dict(input_dict, average=True):
+ """
+ Args:
+ input_dict (dict): all the values will be reduced
+ average (bool): whether to do average or sum
+ Reduce the values in the dictionary from all processes so that process with rank
+ 0 has the averaged results. Returns a dict with the same fields as
+ input_dict, after reduction.
+ """
+ world_size = get_world_size()
+ if world_size < 2:
+ return input_dict
+ with torch.no_grad():
+ names = []
+ values = []
+ # sort the keys so that they are consistent across processes
+ for k in sorted(input_dict.keys()):
+ names.append(k)
+ values.append(input_dict[k])
+ values = torch.stack(values, dim=0)
+ dist.reduce(values, dst=0)
+ if dist.get_rank() == 0 and average:
+ # only main process gets accumulated, so only divide by
+ # world_size in this case
+ values /= world_size
+ reduced_dict = {k: v for k, v in zip(names, values)}
+ return reduced_dict
diff --git a/src/custom_mesh_graphormer/utils/dataset_utils.py b/src/custom_mesh_graphormer/utils/dataset_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb66451e456e4fd8591ef83b34db1d9192cb23cc
--- /dev/null
+++ b/src/custom_mesh_graphormer/utils/dataset_utils.py
@@ -0,0 +1,66 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+"""
+
+
+import os
+import os.path as op
+import numpy as np
+import base64
+import cv2
+import yaml
+from collections import OrderedDict
+
+
+def img_from_base64(imagestring):
+ try:
+ jpgbytestring = base64.b64decode(imagestring)
+ nparr = np.frombuffer(jpgbytestring, np.uint8)
+ r = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+ return r
+ except:
+ return None
+
+
+def load_labelmap(labelmap_file):
+ label_dict = None
+ if labelmap_file is not None and op.isfile(labelmap_file):
+ label_dict = OrderedDict()
+ with open(labelmap_file, 'r') as fp:
+ for line in fp:
+ label = line.strip().split('\t')[0]
+ if label in label_dict:
+ raise ValueError("Duplicate label " + label + " in labelmap.")
+ else:
+ label_dict[label] = len(label_dict)
+ return label_dict
+
+
+def load_shuffle_file(shuf_file):
+ shuf_list = None
+ if shuf_file is not None:
+ with open(shuf_file, 'r') as fp:
+ shuf_list = []
+ for i in fp:
+ shuf_list.append(int(i.strip()))
+ return shuf_list
+
+
+def load_box_shuffle_file(shuf_file):
+ if shuf_file is not None:
+ with open(shuf_file, 'r') as fp:
+ img_shuf_list = []
+ box_shuf_list = []
+ for i in fp:
+ idx = [int(_) for _ in i.strip().split('\t')]
+ img_shuf_list.append(idx[0])
+ box_shuf_list.append(idx[1])
+ return [img_shuf_list, box_shuf_list]
+ return None
+
+
+def load_from_yaml_file(file_name):
+ with open(file_name, 'r') as fp:
+ return yaml.load(fp, Loader=yaml.CLoader)
diff --git a/src/custom_mesh_graphormer/utils/geometric_layers.py b/src/custom_mesh_graphormer/utils/geometric_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad4bf9d10eb089e0887a38501dc8b78f36fc2553
--- /dev/null
+++ b/src/custom_mesh_graphormer/utils/geometric_layers.py
@@ -0,0 +1,58 @@
+"""
+Useful geometric operations, e.g. Orthographic projection and a differentiable Rodrigues formula
+
+Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR
+"""
+import torch
+
+def rodrigues(theta):
+ """Convert axis-angle representation to rotation matrix.
+ Args:
+ theta: size = [B, 3]
+ Returns:
+ Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
+ """
+ l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
+ angle = torch.unsqueeze(l1norm, -1)
+ normalized = torch.div(theta, angle)
+ angle = angle * 0.5
+ v_cos = torch.cos(angle)
+ v_sin = torch.sin(angle)
+ quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
+ return quat2mat(quat)
+
+def quat2mat(quat):
+ """Convert quaternion coefficients to rotation matrix.
+ Args:
+ quat: size = [B, 4] 4 <===>(w, x, y, z)
+ Returns:
+ Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
+ """
+ norm_quat = quat
+ norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
+ w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
+
+ B = quat.size(0)
+
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
+ wx, wy, wz = w*x, w*y, w*z
+ xy, xz, yz = x*y, x*z, y*z
+
+ rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
+ 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
+ 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
+ return rotMat
+
+def orthographic_projection(X, camera):
+ """Perform orthographic projection of 3D points X using the camera parameters
+ Args:
+ X: size = [B, N, 3]
+ camera: size = [B, 3]
+ Returns:
+ Projected 2D points -- size = [B, N, 2]
+ """
+ camera = camera.view(-1, 1, 3)
+ X_trans = X[:, :, :2] + camera[:, :, 1:]
+ shape = X_trans.shape
+ X_2d = (camera[:, :, 0] * X_trans.view(shape[0], -1)).view(shape)
+ return X_2d
diff --git a/src/custom_mesh_graphormer/utils/image_ops.py b/src/custom_mesh_graphormer/utils/image_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb62d8fa2503a24acaa61c7df8cc999f195dfa11
--- /dev/null
+++ b/src/custom_mesh_graphormer/utils/image_ops.py
@@ -0,0 +1,208 @@
+"""
+Image processing tools
+
+Modified from open source projects:
+(https://github.com/nkolot/GraphCMR/)
+(https://github.com/open-mmlab/mmdetection)
+
+"""
+
+import numpy as np
+import base64
+import cv2
+import torch
+import scipy.misc
+
+def img_from_base64(imagestring):
+ try:
+ jpgbytestring = base64.b64decode(imagestring)
+ nparr = np.frombuffer(jpgbytestring, np.uint8)
+ r = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+ return r
+ except ValueError:
+ return None
+
+def myimrotate(img, angle, center=None, scale=1.0, border_value=0, auto_bound=False):
+ if center is not None and auto_bound:
+ raise ValueError('`auto_bound` conflicts with `center`')
+ h, w = img.shape[:2]
+ if center is None:
+ center = ((w - 1) * 0.5, (h - 1) * 0.5)
+ assert isinstance(center, tuple)
+
+ matrix = cv2.getRotationMatrix2D(center, angle, scale)
+ if auto_bound:
+ cos = np.abs(matrix[0, 0])
+ sin = np.abs(matrix[0, 1])
+ new_w = h * sin + w * cos
+ new_h = h * cos + w * sin
+ matrix[0, 2] += (new_w - w) * 0.5
+ matrix[1, 2] += (new_h - h) * 0.5
+ w = int(np.round(new_w))
+ h = int(np.round(new_h))
+ rotated = cv2.warpAffine(img, matrix, (w, h), borderValue=border_value)
+ return rotated
+
+def myimresize(img, size, return_scale=False, interpolation='bilinear'):
+
+ h, w = img.shape[:2]
+ resized_img = cv2.resize(
+ img, (size[0],size[1]), interpolation=cv2.INTER_LINEAR)
+ if not return_scale:
+ return resized_img
+ else:
+ w_scale = size[0] / w
+ h_scale = size[1] / h
+ return resized_img, w_scale, h_scale
+
+
+def get_transform(center, scale, res, rot=0):
+ """Generate transformation matrix."""
+ h = 200 * scale
+ t = np.zeros((3, 3))
+ t[0, 0] = float(res[1]) / h
+ t[1, 1] = float(res[0]) / h
+ t[0, 2] = res[1] * (-float(center[0]) / h + .5)
+ t[1, 2] = res[0] * (-float(center[1]) / h + .5)
+ t[2, 2] = 1
+ if not rot == 0:
+ rot = -rot # To match direction of rotation from cropping
+ rot_mat = np.zeros((3,3))
+ rot_rad = rot * np.pi / 180
+ sn,cs = np.sin(rot_rad), np.cos(rot_rad)
+ rot_mat[0,:2] = [cs, -sn]
+ rot_mat[1,:2] = [sn, cs]
+ rot_mat[2,2] = 1
+ # Need to rotate around center
+ t_mat = np.eye(3)
+ t_mat[0,2] = -res[1]/2
+ t_mat[1,2] = -res[0]/2
+ t_inv = t_mat.copy()
+ t_inv[:2,2] *= -1
+ t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t)))
+ return t
+
+def transform(pt, center, scale, res, invert=0, rot=0):
+ """Transform pixel location to different reference."""
+ t = get_transform(center, scale, res, rot=rot)
+ if invert:
+ # t = np.linalg.inv(t)
+ t_torch = torch.from_numpy(t)
+ t_torch = torch.inverse(t_torch)
+ t = t_torch.numpy()
+ new_pt = np.array([pt[0]-1, pt[1]-1, 1.]).T
+ new_pt = np.dot(t, new_pt)
+ return new_pt[:2].astype(int)+1
+
+def crop(img, center, scale, res, rot=0):
+ """Crop image according to the supplied bounding box."""
+ # Upper left point
+ ul = np.array(transform([1, 1], center, scale, res, invert=1))-1
+ # Bottom right point
+ br = np.array(transform([res[0]+1,
+ res[1]+1], center, scale, res, invert=1))-1
+ # Padding so that when rotated proper amount of context is included
+ pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
+ if not rot == 0:
+ ul -= pad
+ br += pad
+ new_shape = [br[1] - ul[1], br[0] - ul[0]]
+ if len(img.shape) > 2:
+ new_shape += [img.shape[2]]
+ new_img = np.zeros(new_shape)
+
+ # Range to fill new array
+ new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
+ new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
+ # Range to sample from original image
+ old_x = max(0, ul[0]), min(len(img[0]), br[0])
+ old_y = max(0, ul[1]), min(len(img), br[1])
+
+ new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
+ old_x[0]:old_x[1]]
+ if not rot == 0:
+ # Remove padding
+ # new_img = scipy.misc.imrotate(new_img, rot)
+ new_img = myimrotate(new_img, rot)
+ new_img = new_img[pad:-pad, pad:-pad]
+
+ # new_img = scipy.misc.imresize(new_img, res)
+ new_img = myimresize(new_img, [res[0], res[1]])
+ return new_img
+
+def uncrop(img, center, scale, orig_shape, rot=0, is_rgb=True):
+ """'Undo' the image cropping/resizing.
+ This function is used when evaluating mask/part segmentation.
+ """
+ res = img.shape[:2]
+ # Upper left point
+ ul = np.array(transform([1, 1], center, scale, res, invert=1))-1
+ # Bottom right point
+ br = np.array(transform([res[0]+1,res[1]+1], center, scale, res, invert=1))-1
+ # size of cropped image
+ crop_shape = [br[1] - ul[1], br[0] - ul[0]]
+
+ new_shape = [br[1] - ul[1], br[0] - ul[0]]
+ if len(img.shape) > 2:
+ new_shape += [img.shape[2]]
+ new_img = np.zeros(orig_shape, dtype=np.uint8)
+ # Range to fill new array
+ new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0]
+ new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1]
+ # Range to sample from original image
+ old_x = max(0, ul[0]), min(orig_shape[1], br[0])
+ old_y = max(0, ul[1]), min(orig_shape[0], br[1])
+ # img = scipy.misc.imresize(img, crop_shape, interp='nearest')
+ img = myimresize(img, [crop_shape[0],crop_shape[1]])
+ new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]]
+ return new_img
+
+def rot_aa(aa, rot):
+ """Rotate axis angle parameters."""
+ # pose parameters
+ R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
+ [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
+ [0, 0, 1]])
+ # find the rotation of the body in camera frame
+ per_rdg, _ = cv2.Rodrigues(aa)
+ # apply the global rotation to the global orientation
+ resrot, _ = cv2.Rodrigues(np.dot(R,per_rdg))
+ aa = (resrot.T)[0]
+ return aa
+
+def flip_img(img):
+ """Flip rgb images or masks.
+ channels come last, e.g. (256,256,3).
+ """
+ img = np.fliplr(img)
+ return img
+
+def flip_kp(kp):
+ """Flip keypoints."""
+ flipped_parts = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17, 18, 19, 21, 20, 23, 22]
+ kp = kp[flipped_parts]
+ kp[:,0] = - kp[:,0]
+ return kp
+
+def flip_pose(pose):
+ """Flip pose.
+ The flipping is based on SMPL parameters.
+ """
+ flippedParts = [0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10, 11, 15, 16, 17, 12, 13,
+ 14 ,18, 19, 20, 24, 25, 26, 21, 22, 23, 27, 28, 29, 33,
+ 34, 35, 30, 31, 32, 36, 37, 38, 42, 43, 44, 39, 40, 41,
+ 45, 46, 47, 51, 52, 53, 48, 49, 50, 57, 58, 59, 54, 55,
+ 56, 63, 64, 65, 60, 61, 62, 69, 70, 71, 66, 67, 68]
+ pose = pose[flippedParts]
+ # we also negate the second and the third dimension of the axis-angle
+ pose[1::3] = -pose[1::3]
+ pose[2::3] = -pose[2::3]
+ return pose
+
+def flip_aa(aa):
+ """Flip axis-angle representation.
+ We negate the second and the third dimension of the axis-angle.
+ """
+ aa[1] = -aa[1]
+ aa[2] = -aa[2]
+ return aa
\ No newline at end of file
diff --git a/src/custom_mesh_graphormer/utils/logger.py b/src/custom_mesh_graphormer/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..013139605fc60634265b269d5260d4075e2b6478
--- /dev/null
+++ b/src/custom_mesh_graphormer/utils/logger.py
@@ -0,0 +1,100 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import logging
+import os
+import sys
+from logging import StreamHandler, Handler, getLevelName
+
+
+# this class is a copy of logging.FileHandler except we end self.close()
+# at the end of each emit. While closing file and reopening file after each
+# write is not efficient, it allows us to see partial logs when writing to
+# fused Azure blobs, which is very convenient
+class FileHandler(StreamHandler):
+ """
+ A handler class which writes formatted logging records to disk files.
+ """
+ def __init__(self, filename, mode='a', encoding=None, delay=False):
+ """
+ Open the specified file and use it as the stream for logging.
+ """
+ # Issue #27493: add support for Path objects to be passed in
+ filename = os.fspath(filename)
+ #keep the absolute path, otherwise derived classes which use this
+ #may come a cropper when the current directory changes
+ self.baseFilename = os.path.abspath(filename)
+ self.mode = mode
+ self.encoding = encoding
+ self.delay = delay
+ if delay:
+ #We don't open the stream, but we still need to call the
+ #Handler constructor to set level, formatter, lock etc.
+ Handler.__init__(self)
+ self.stream = None
+ else:
+ StreamHandler.__init__(self, self._open())
+
+ def close(self):
+ """
+ Closes the stream.
+ """
+ self.acquire()
+ try:
+ try:
+ if self.stream:
+ try:
+ self.flush()
+ finally:
+ stream = self.stream
+ self.stream = None
+ if hasattr(stream, "close"):
+ stream.close()
+ finally:
+ # Issue #19523: call unconditionally to
+ # prevent a handler leak when delay is set
+ StreamHandler.close(self)
+ finally:
+ self.release()
+
+ def _open(self):
+ """
+ Open the current base file with the (original) mode and encoding.
+ Return the resulting stream.
+ """
+ return open(self.baseFilename, self.mode, encoding=self.encoding)
+
+ def emit(self, record):
+ """
+ Emit a record.
+
+ If the stream was not opened because 'delay' was specified in the
+ constructor, open it before calling the superclass's emit.
+ """
+ if self.stream is None:
+ self.stream = self._open()
+ StreamHandler.emit(self, record)
+ self.close()
+
+ def __repr__(self):
+ level = getLevelName(self.level)
+ return '<%s %s (%s)>' % (self.__class__.__name__, self.baseFilename, level)
+
+
+def setup_logger(name, save_dir, distributed_rank, filename="log.txt"):
+ logger = logging.getLogger(name)
+ logger.setLevel(logging.DEBUG)
+ # don't log results for the non-master process
+ if distributed_rank > 0:
+ return logger
+ ch = logging.StreamHandler(stream=sys.stdout)
+ ch.setLevel(logging.DEBUG)
+ formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
+ ch.setFormatter(formatter)
+ logger.addHandler(ch)
+
+ if save_dir:
+ fh = FileHandler(os.path.join(save_dir, filename))
+ fh.setLevel(logging.DEBUG)
+ fh.setFormatter(formatter)
+ logger.addHandler(fh)
+
+ return logger
diff --git a/src/custom_mesh_graphormer/utils/metric_logger.py b/src/custom_mesh_graphormer/utils/metric_logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddaa0ab3ac95314cb94b5cb8c2a76a838f141a59
--- /dev/null
+++ b/src/custom_mesh_graphormer/utils/metric_logger.py
@@ -0,0 +1,45 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+Basic logger. It Computes and stores the average and current value
+"""
+
+class AverageMeter(object):
+
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+
+class EvalMetricsLogger(object):
+
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ # define a upper-bound performance (worst case)
+ # numbers are in unit millimeter
+ self.PAmPJPE = 100.0/1000.0
+ self.mPJPE = 100.0/1000.0
+ self.mPVE = 100.0/1000.0
+
+ self.epoch = 0
+
+ def update(self, mPVE, mPJPE, PAmPJPE, epoch):
+ self.PAmPJPE = PAmPJPE
+ self.mPJPE = mPJPE
+ self.mPVE = mPVE
+ self.epoch = epoch
diff --git a/src/custom_mesh_graphormer/utils/metric_pampjpe.py b/src/custom_mesh_graphormer/utils/metric_pampjpe.py
new file mode 100644
index 0000000000000000000000000000000000000000..89fe55b7f8d0973670f3cc141a666528dd19ebd1
--- /dev/null
+++ b/src/custom_mesh_graphormer/utils/metric_pampjpe.py
@@ -0,0 +1,99 @@
+"""
+Functions for compuing Procrustes alignment and reconstruction error
+
+Parts of the code are adapted from https://github.com/akanazawa/hmr
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import numpy as np
+
+def compute_similarity_transform(S1, S2):
+ """Computes a similarity transform (sR, t) that takes
+ a set of 3D points S1 (3 x N) closest to a set of 3D points S2,
+ where R is an 3x3 rotation matrix, t 3x1 translation, s scale.
+ i.e. solves the orthogonal Procrutes problem.
+ """
+ transposed = False
+ if S1.shape[0] != 3 and S1.shape[0] != 2:
+ S1 = S1.T
+ S2 = S2.T
+ transposed = True
+ assert(S2.shape[1] == S1.shape[1])
+
+ # 1. Remove mean.
+ mu1 = S1.mean(axis=1, keepdims=True)
+ mu2 = S2.mean(axis=1, keepdims=True)
+ X1 = S1 - mu1
+ X2 = S2 - mu2
+
+ # 2. Compute variance of X1 used for scale.
+ var1 = np.sum(X1**2)
+
+ # 3. The outer product of X1 and X2.
+ K = X1.dot(X2.T)
+
+ # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are
+ # singular vectors of K.
+ U, s, Vh = np.linalg.svd(K)
+ V = Vh.T
+ # Construct Z that fixes the orientation of R to get det(R)=1.
+ Z = np.eye(U.shape[0])
+ Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T)))
+ # Construct R.
+ R = V.dot(Z.dot(U.T))
+
+ # 5. Recover scale.
+ scale = np.trace(R.dot(K)) / var1
+
+ # 6. Recover translation.
+ t = mu2 - scale*(R.dot(mu1))
+
+ # 7. Error:
+ S1_hat = scale*R.dot(S1) + t
+
+ if transposed:
+ S1_hat = S1_hat.T
+
+ return S1_hat
+
+def compute_similarity_transform_batch(S1, S2):
+ """Batched version of compute_similarity_transform."""
+ S1_hat = np.zeros_like(S1)
+ for i in range(S1.shape[0]):
+ S1_hat[i] = compute_similarity_transform(S1[i], S2[i])
+ return S1_hat
+
+def reconstruction_error(S1, S2, reduction='mean'):
+ """Do Procrustes alignment and compute reconstruction error."""
+ S1_hat = compute_similarity_transform_batch(S1, S2)
+ re = np.sqrt( ((S1_hat - S2)** 2).sum(axis=-1)).mean(axis=-1)
+ if reduction == 'mean':
+ re = re.mean()
+ elif reduction == 'sum':
+ re = re.sum()
+ return re
+
+
+def reconstruction_error_v2(S1, S2, J24_TO_J14, reduction='mean'):
+ """Do Procrustes alignment and compute reconstruction error."""
+ S1_hat = compute_similarity_transform_batch(S1, S2)
+ S1_hat = S1_hat[:,J24_TO_J14,:]
+ S2 = S2[:,J24_TO_J14,:]
+ re = np.sqrt( ((S1_hat - S2)** 2).sum(axis=-1)).mean(axis=-1)
+ if reduction == 'mean':
+ re = re.mean()
+ elif reduction == 'sum':
+ re = re.sum()
+ return re
+
+def get_alignMesh(S1, S2, reduction='mean'):
+ """Do Procrustes alignment and compute reconstruction error."""
+ S1_hat = compute_similarity_transform_batch(S1, S2)
+ re = np.sqrt( ((S1_hat - S2)** 2).sum(axis=-1)).mean(axis=-1)
+ if reduction == 'mean':
+ re = re.mean()
+ elif reduction == 'sum':
+ re = re.sum()
+ return re, S1_hat, S2
diff --git a/src/custom_mesh_graphormer/utils/miscellaneous.py b/src/custom_mesh_graphormer/utils/miscellaneous.py
new file mode 100644
index 0000000000000000000000000000000000000000..3de72c69c8fcd4502dc5e8b58656b3ee08db7554
--- /dev/null
+++ b/src/custom_mesh_graphormer/utils/miscellaneous.py
@@ -0,0 +1,171 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import errno
+import os
+import os.path as op
+import re
+import logging
+import numpy as np
+import torch
+import random
+import shutil
+from .comm import is_main_process
+import yaml
+
+
+def mkdir(path):
+ # if it is the current folder, skip.
+ # otherwise the original code will raise FileNotFoundError
+ if path == '':
+ return
+ try:
+ os.makedirs(path)
+ except OSError as e:
+ if e.errno != errno.EEXIST:
+ raise
+
+
+def save_config(cfg, path):
+ if is_main_process():
+ with open(path, 'w') as f:
+ f.write(cfg.dump())
+
+
+def config_iteration(output_dir, max_iter):
+ save_file = os.path.join(output_dir, 'last_checkpoint')
+ iteration = -1
+ if os.path.exists(save_file):
+ with open(save_file, 'r') as f:
+ fname = f.read().strip()
+ model_name = os.path.basename(fname)
+ model_path = os.path.dirname(fname)
+ if model_name.startswith('model_') and len(model_name) == 17:
+ iteration = int(model_name[-11:-4])
+ elif model_name == "model_final":
+ iteration = max_iter
+ elif model_path.startswith('checkpoint-') and len(model_path) == 18:
+ iteration = int(model_path.split('-')[-1])
+ return iteration
+
+
+def get_matching_parameters(model, regexp, none_on_empty=True):
+ """Returns parameters matching regular expression"""
+ if not regexp:
+ if none_on_empty:
+ return {}
+ else:
+ return dict(model.named_parameters())
+ compiled_pattern = re.compile(regexp)
+ params = {}
+ for weight_name, weight in model.named_parameters():
+ if compiled_pattern.match(weight_name):
+ params[weight_name] = weight
+ return params
+
+
+def freeze_weights(model, regexp):
+ """Freeze weights based on regular expression."""
+ logger = logging.getLogger("maskrcnn_benchmark.trainer")
+ for weight_name, weight in get_matching_parameters(model, regexp).items():
+ weight.requires_grad = False
+ logger.info("Disabled training of {}".format(weight_name))
+
+
+def unfreeze_weights(model, regexp, backbone_freeze_at=-1,
+ is_distributed=False):
+ """Unfreeze weights based on regular expression.
+ This is helpful during training to unfreeze freezed weights after
+ other unfreezed weights have been trained for some iterations.
+ """
+ logger = logging.getLogger("maskrcnn_benchmark.trainer")
+ for weight_name, weight in get_matching_parameters(model, regexp).items():
+ weight.requires_grad = True
+ logger.info("Enabled training of {}".format(weight_name))
+ if backbone_freeze_at >= 0:
+ logger.info("Freeze backbone at stage: {}".format(backbone_freeze_at))
+ if is_distributed:
+ model.module.backbone.body._freeze_backbone(backbone_freeze_at)
+ else:
+ model.backbone.body._freeze_backbone(backbone_freeze_at)
+
+
+def delete_tsv_files(tsvs):
+ for t in tsvs:
+ if op.isfile(t):
+ try_delete(t)
+ line = op.splitext(t)[0] + '.lineidx'
+ if op.isfile(line):
+ try_delete(line)
+
+
+def concat_files(ins, out):
+ mkdir(op.dirname(out))
+ out_tmp = out + '.tmp'
+ with open(out_tmp, 'wb') as fp_out:
+ for i, f in enumerate(ins):
+ logging.info('concating {}/{} - {}'.format(i, len(ins), f))
+ with open(f, 'rb') as fp_in:
+ shutil.copyfileobj(fp_in, fp_out, 1024*1024*10)
+ os.rename(out_tmp, out)
+
+
+def concat_tsv_files(tsvs, out_tsv):
+ concat_files(tsvs, out_tsv)
+ sizes = [os.stat(t).st_size for t in tsvs]
+ sizes = np.cumsum(sizes)
+ all_idx = []
+ for i, t in enumerate(tsvs):
+ for idx in load_list_file(op.splitext(t)[0] + '.lineidx'):
+ if i == 0:
+ all_idx.append(idx)
+ else:
+ all_idx.append(str(int(idx) + sizes[i - 1]))
+ with open(op.splitext(out_tsv)[0] + '.lineidx', 'w') as f:
+ f.write('\n'.join(all_idx))
+
+
+def load_list_file(fname):
+ with open(fname, 'r') as fp:
+ lines = fp.readlines()
+ result = [line.strip() for line in lines]
+ if len(result) > 0 and result[-1] == '':
+ result = result[:-1]
+ return result
+
+
+def try_once(func):
+ def func_wrapper(*args, **kwargs):
+ try:
+ return func(*args, **kwargs)
+ except Exception as e:
+ logging.info('ignore error \n{}'.format(str(e)))
+ return func_wrapper
+
+
+@try_once
+def try_delete(f):
+ os.remove(f)
+
+
+def set_seed(seed, n_gpu):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ if n_gpu > 0:
+ torch.cuda.manual_seed_all(seed)
+
+
+def print_and_run_cmd(cmd):
+ print(cmd)
+ os.system(cmd)
+
+
+def write_to_yaml_file(context, file_name):
+ with open(file_name, 'w') as fp:
+ yaml.dump(context, fp, encoding='utf-8')
+
+
+def load_from_yaml_file(yaml_file):
+ with open(yaml_file, 'r') as fp:
+ return yaml.load(fp, Loader=yaml.CLoader)
+
+
diff --git a/src/custom_mesh_graphormer/utils/renderer.py b/src/custom_mesh_graphormer/utils/renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b2f8a96e7c227466a07a19650e75228f5aa6860
--- /dev/null
+++ b/src/custom_mesh_graphormer/utils/renderer.py
@@ -0,0 +1,691 @@
+"""
+Rendering tools for 3D mesh visualization on 2D image.
+
+Parts of the code are taken from https://github.com/akanazawa/hmr
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import cv2
+import code
+from opendr.camera import ProjectPoints
+from opendr.renderer import ColoredRenderer, TexturedRenderer
+from opendr.lighting import LambertianPointLight
+import random
+
+
+# Rotate the points by a specified angle.
+def rotateY(points, angle):
+ ry = np.array([
+ [np.cos(angle), 0., np.sin(angle)], [0., 1., 0.],
+ [-np.sin(angle), 0., np.cos(angle)]
+ ])
+ return np.dot(points, ry)
+
+def draw_skeleton(input_image, joints, draw_edges=True, vis=None, radius=None):
+ """
+ joints is 3 x 19. but if not will transpose it.
+ 0: Right ankle
+ 1: Right knee
+ 2: Right hip
+ 3: Left hip
+ 4: Left knee
+ 5: Left ankle
+ 6: Right wrist
+ 7: Right elbow
+ 8: Right shoulder
+ 9: Left shoulder
+ 10: Left elbow
+ 11: Left wrist
+ 12: Neck
+ 13: Head top
+ 14: nose
+ 15: left_eye
+ 16: right_eye
+ 17: left_ear
+ 18: right_ear
+ """
+
+ if radius is None:
+ radius = max(4, (np.mean(input_image.shape[:2]) * 0.01).astype(int))
+
+ colors = {
+ 'pink': (197, 27, 125), # L lower leg
+ 'light_pink': (233, 163, 201), # L upper leg
+ 'light_green': (161, 215, 106), # L lower arm
+ 'green': (77, 146, 33), # L upper arm
+ 'red': (215, 48, 39), # head
+ 'light_red': (252, 146, 114), # head
+ 'light_orange': (252, 141, 89), # chest
+ 'purple': (118, 42, 131), # R lower leg
+ 'light_purple': (175, 141, 195), # R upper
+ 'light_blue': (145, 191, 219), # R lower arm
+ 'blue': (69, 117, 180), # R upper arm
+ 'gray': (130, 130, 130), #
+ 'white': (255, 255, 255), #
+ }
+
+ image = input_image.copy()
+ input_is_float = False
+
+ if np.issubdtype(image.dtype, np.float):
+ input_is_float = True
+ max_val = image.max()
+ if max_val <= 2.: # should be 1 but sometimes it's slightly above 1
+ image = (image * 255).astype(np.uint8)
+ else:
+ image = (image).astype(np.uint8)
+
+ if joints.shape[0] != 2:
+ joints = joints.T
+ joints = np.round(joints).astype(int)
+
+ jcolors = [
+ 'light_pink', 'light_pink', 'light_pink', 'pink', 'pink', 'pink',
+ 'light_blue', 'light_blue', 'light_blue', 'blue', 'blue', 'blue',
+ 'purple', 'purple', 'red', 'green', 'green', 'white', 'white',
+ 'purple', 'purple', 'red', 'green', 'green', 'white', 'white'
+ ]
+
+ if joints.shape[1] == 19:
+ # parent indices -1 means no parents
+ parents = np.array([
+ 1, 2, 8, 9, 3, 4, 7, 8, 12, 12, 9, 10, 14, -1, 13, -1, -1, 15, 16
+ ])
+ # Left is light and right is dark
+ ecolors = {
+ 0: 'light_pink',
+ 1: 'light_pink',
+ 2: 'light_pink',
+ 3: 'pink',
+ 4: 'pink',
+ 5: 'pink',
+ 6: 'light_blue',
+ 7: 'light_blue',
+ 8: 'light_blue',
+ 9: 'blue',
+ 10: 'blue',
+ 11: 'blue',
+ 12: 'purple',
+ 17: 'light_green',
+ 18: 'light_green',
+ 14: 'purple'
+ }
+ elif joints.shape[1] == 14:
+ parents = np.array([
+ 1,
+ 2,
+ 8,
+ 9,
+ 3,
+ 4,
+ 7,
+ 8,
+ -1,
+ -1,
+ 9,
+ 10,
+ 13,
+ -1,
+ ])
+ ecolors = {
+ 0: 'light_pink',
+ 1: 'light_pink',
+ 2: 'light_pink',
+ 3: 'pink',
+ 4: 'pink',
+ 5: 'pink',
+ 6: 'light_blue',
+ 7: 'light_blue',
+ 10: 'light_blue',
+ 11: 'blue',
+ 12: 'purple'
+ }
+ elif joints.shape[1] == 21: # hand
+ parents = np.array([
+ -1,
+ 0,
+ 1,
+ 2,
+ 3,
+ 0,
+ 5,
+ 6,
+ 7,
+ 0,
+ 9,
+ 10,
+ 11,
+ 0,
+ 13,
+ 14,
+ 15,
+ 0,
+ 17,
+ 18,
+ 19,
+ ])
+ ecolors = {
+ 0: 'light_purple',
+ 1: 'light_green',
+ 2: 'light_green',
+ 3: 'light_green',
+ 4: 'light_green',
+ 5: 'pink',
+ 6: 'pink',
+ 7: 'pink',
+ 8: 'pink',
+ 9: 'light_blue',
+ 10: 'light_blue',
+ 11: 'light_blue',
+ 12: 'light_blue',
+ 13: 'light_red',
+ 14: 'light_red',
+ 15: 'light_red',
+ 16: 'light_red',
+ 17: 'purple',
+ 18: 'purple',
+ 19: 'purple',
+ 20: 'purple',
+ }
+ else:
+ print('Unknown skeleton!!')
+
+ for child in range(len(parents)):
+ point = joints[:, child]
+ # If invisible skip
+ if vis is not None and vis[child] == 0:
+ continue
+ if draw_edges:
+ cv2.circle(image, (point[0], point[1]), radius, colors['white'],
+ -1)
+ cv2.circle(image, (point[0], point[1]), radius - 1,
+ colors[jcolors[child]], -1)
+ else:
+ # cv2.circle(image, (point[0], point[1]), 5, colors['white'], 1)
+ cv2.circle(image, (point[0], point[1]), radius - 1,
+ colors[jcolors[child]], 1)
+ # cv2.circle(image, (point[0], point[1]), 5, colors['gray'], -1)
+ pa_id = parents[child]
+ if draw_edges and pa_id >= 0:
+ if vis is not None and vis[pa_id] == 0:
+ continue
+ point_pa = joints[:, pa_id]
+ cv2.circle(image, (point_pa[0], point_pa[1]), radius - 1,
+ colors[jcolors[pa_id]], -1)
+ if child not in ecolors.keys():
+ print('bad')
+ import ipdb
+ ipdb.set_trace()
+ cv2.line(image, (point[0], point[1]), (point_pa[0], point_pa[1]),
+ colors[ecolors[child]], radius - 2)
+
+ # Convert back in original dtype
+ if input_is_float:
+ if max_val <= 1.:
+ image = image.astype(np.float32) / 255.
+ else:
+ image = image.astype(np.float32)
+
+ return image
+
+def draw_text(input_image, content):
+ """
+ content is a dict. draws key: val on image
+ Assumes key is str, val is float
+ """
+ image = input_image.copy()
+ input_is_float = False
+ if np.issubdtype(image.dtype, np.float):
+ input_is_float = True
+ image = (image * 255).astype(np.uint8)
+
+ black = (255, 255, 0)
+ margin = 15
+ start_x = 5
+ start_y = margin
+ for key in sorted(content.keys()):
+ text = "%s: %.2g" % (key, content[key])
+ cv2.putText(image, text, (start_x, start_y), 0, 0.45, black)
+ start_y += margin
+
+ if input_is_float:
+ image = image.astype(np.float32) / 255.
+ return image
+
+def visualize_reconstruction(img, img_size, gt_kp, vertices, pred_kp, camera, renderer, color='pink', focal_length=1000):
+ """Overlays gt_kp and pred_kp on img.
+ Draws vert with text.
+ Renderer is an instance of SMPLRenderer.
+ """
+ gt_vis = gt_kp[:, 2].astype(bool)
+ loss = np.sum((gt_kp[gt_vis, :2] - pred_kp[gt_vis])**2)
+ debug_text = {"sc": camera[0], "tx": camera[1], "ty": camera[2], "kpl": loss}
+ # Fix a flength so i can render this with persp correct scale
+ res = img.shape[1]
+ camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] +1e-9)])
+ rend_img = renderer.render(vertices, camera_t=camera_t,
+ img=img, use_bg=True,
+ focal_length=focal_length,
+ body_color=color)
+ rend_img = draw_text(rend_img, debug_text)
+
+ # Draw skeleton
+ gt_joint = ((gt_kp[:, :2] + 1) * 0.5) * img_size
+ pred_joint = ((pred_kp + 1) * 0.5) * img_size
+ img_with_gt = draw_skeleton( img, gt_joint, draw_edges=False, vis=gt_vis)
+ skel_img = draw_skeleton(img_with_gt, pred_joint)
+
+ combined = np.hstack([skel_img, rend_img])
+
+ return combined
+
+def visualize_reconstruction_test(img, img_size, gt_kp, vertices, pred_kp, camera, renderer, score, color='pink', focal_length=1000):
+ """Overlays gt_kp and pred_kp on img.
+ Draws vert with text.
+ Renderer is an instance of SMPLRenderer.
+ """
+ gt_vis = gt_kp[:, 2].astype(bool)
+ loss = np.sum((gt_kp[gt_vis, :2] - pred_kp[gt_vis])**2)
+ debug_text = {"sc": camera[0], "tx": camera[1], "ty": camera[2], "kpl": loss, "pa-mpjpe": score*1000}
+ # Fix a flength so i can render this with persp correct scale
+ res = img.shape[1]
+ camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] +1e-9)])
+ rend_img = renderer.render(vertices, camera_t=camera_t,
+ img=img, use_bg=True,
+ focal_length=focal_length,
+ body_color=color)
+ rend_img = draw_text(rend_img, debug_text)
+
+ # Draw skeleton
+ gt_joint = ((gt_kp[:, :2] + 1) * 0.5) * img_size
+ pred_joint = ((pred_kp + 1) * 0.5) * img_size
+ img_with_gt = draw_skeleton( img, gt_joint, draw_edges=False, vis=gt_vis)
+ skel_img = draw_skeleton(img_with_gt, pred_joint)
+
+ combined = np.hstack([skel_img, rend_img])
+
+ return combined
+
+
+
+def visualize_reconstruction_and_att(img, img_size, vertices_full, vertices, vertices_2d, camera, renderer, ref_points, attention, focal_length=1000):
+ """Overlays gt_kp and pred_kp on img.
+ Draws vert with text.
+ Renderer is an instance of SMPLRenderer.
+ """
+ # Fix a flength so i can render this with persp correct scale
+ res = img.shape[1]
+ camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] +1e-9)])
+ rend_img = renderer.render(vertices_full, camera_t=camera_t,
+ img=img, use_bg=True,
+ focal_length=focal_length, body_color='light_blue')
+
+
+ heads_num, vertex_num, _ = attention.shape
+
+ all_head = np.zeros((vertex_num,vertex_num))
+
+ ###### find max
+ # for i in range(vertex_num):
+ # for j in range(vertex_num):
+ # all_head[i,j] = np.max(attention[:,i,j])
+
+ ##### find avg
+ for h in range(4):
+ att_per_img = attention[h]
+ all_head = all_head + att_per_img
+ all_head = all_head/4
+
+ col_sums = all_head.sum(axis=0)
+ all_head = all_head / col_sums[np.newaxis, :]
+
+
+ # code.interact(local=locals())
+
+ combined = []
+ if vertex_num>400: # body
+ selected_joints = [6,7,4,5,13] # [6,7,4,5,13,12]
+ else: # hand
+ selected_joints = [0, 4, 8, 12, 16, 20]
+ # Draw attention
+ for ii in range(len(selected_joints)):
+ reference_id = selected_joints[ii]
+ ref_point = ref_points[reference_id]
+ attention_to_show = all_head[reference_id][14::]
+ min_v = np.min(attention_to_show)
+ max_v = np.max(attention_to_show)
+ norm_attention_to_show = (attention_to_show - min_v)/(max_v-min_v)
+
+ vertices_norm = ((vertices_2d + 1) * 0.5) * img_size
+ ref_norm = ((ref_point + 1) * 0.5) * img_size
+ image = np.zeros_like(rend_img)
+
+ for jj in range(vertices_norm.shape[0]):
+ x = int(vertices_norm[jj,0])
+ y = int(vertices_norm[jj,1])
+ cv2.circle(image,(x,y), 1, (255,255,255), -1)
+
+ total_to_draw = []
+ for jj in range(vertices_norm.shape[0]):
+ thres = 0.0
+ if norm_attention_to_show[jj]>thres:
+ things = [norm_attention_to_show[jj], ref_norm, vertices_norm[jj]]
+ total_to_draw.append(things)
+ # plot_one_line(ref_norm, vertices_norm[jj], image, reference_id, alpha=0.4*(norm_attention_to_show[jj]-thres)/(1-thres) )
+ total_to_draw.sort()
+ max_att_score = total_to_draw[-1][0]
+ for item in total_to_draw:
+ attention_score = item[0]
+ ref_point = item[1]
+ vertex = item[2]
+ plot_one_line(ref_point, vertex, image, ii, alpha=(attention_score-thres)/(max_att_score-thres) )
+ # code.interact(local=locals())
+ if len(combined)==0:
+ combined = image
+ else:
+ combined = np.hstack([combined, image])
+
+ final = np.hstack([img, combined, rend_img])
+
+ return final
+
+
+def visualize_reconstruction_and_att_local(img, img_size, vertices_full, vertices, vertices_2d, camera, renderer, ref_points, attention, color='light_blue', focal_length=1000):
+ """Overlays gt_kp and pred_kp on img.
+ Draws vert with text.
+ Renderer is an instance of SMPLRenderer.
+ """
+ # Fix a flength so i can render this with persp correct scale
+ res = img.shape[1]
+ camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] +1e-9)])
+ rend_img = renderer.render(vertices_full, camera_t=camera_t,
+ img=img, use_bg=True,
+ focal_length=focal_length, body_color=color)
+ heads_num, vertex_num, _ = attention.shape
+ all_head = np.zeros((vertex_num,vertex_num))
+
+ ##### compute avg attention for 4 attention heads
+ for h in range(4):
+ att_per_img = attention[h]
+ all_head = all_head + att_per_img
+ all_head = all_head/4
+
+ col_sums = all_head.sum(axis=0)
+ all_head = all_head / col_sums[np.newaxis, :]
+
+ combined = []
+ if vertex_num>400: # body
+ selected_joints = [7] # [6,7,4,5,13,12]
+ else: # hand
+ selected_joints = [0] # [0, 4, 8, 12, 16, 20]
+ # Draw attention
+ for ii in range(len(selected_joints)):
+ reference_id = selected_joints[ii]
+ ref_point = ref_points[reference_id]
+ attention_to_show = all_head[reference_id][14::]
+ min_v = np.min(attention_to_show)
+ max_v = np.max(attention_to_show)
+ norm_attention_to_show = (attention_to_show - min_v)/(max_v-min_v)
+ vertices_norm = ((vertices_2d + 1) * 0.5) * img_size
+ ref_norm = ((ref_point + 1) * 0.5) * img_size
+ image = rend_img*0.4
+
+ total_to_draw = []
+ for jj in range(vertices_norm.shape[0]):
+ thres = 0.0
+ if norm_attention_to_show[jj]>thres:
+ things = [norm_attention_to_show[jj], ref_norm, vertices_norm[jj]]
+ total_to_draw.append(things)
+ total_to_draw.sort()
+ max_att_score = total_to_draw[-1][0]
+ for item in total_to_draw:
+ attention_score = item[0]
+ ref_point = item[1]
+ vertex = item[2]
+ plot_one_line(ref_point, vertex, image, ii, alpha=(attention_score-thres)/(max_att_score-thres) )
+
+ for jj in range(vertices_norm.shape[0]):
+ x = int(vertices_norm[jj,0])
+ y = int(vertices_norm[jj,1])
+ cv2.circle(image,(x,y), 1, (255,255,255), -1)
+
+ if len(combined)==0:
+ combined = image
+ else:
+ combined = np.hstack([combined, image])
+
+ final = np.hstack([img, combined, rend_img])
+
+ return final
+
+
+def visualize_reconstruction_no_text(img, img_size, vertices, camera, renderer, color='pink', focal_length=1000):
+ """Overlays gt_kp and pred_kp on img.
+ Draws vert with text.
+ Renderer is an instance of SMPLRenderer.
+ """
+ # Fix a flength so i can render this with persp correct scale
+ res = img.shape[1]
+ camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] +1e-9)])
+ rend_img = renderer.render(vertices, camera_t=camera_t,
+ img=img, use_bg=True,
+ focal_length=focal_length,
+ body_color=color)
+
+
+ combined = np.hstack([img, rend_img])
+
+ return combined
+
+
+def plot_one_line(ref, vertex, img, color_index, alpha=0.0, line_thickness=None):
+ # 13,6,7,8,3,4,5
+ # att_colors = [(255, 221, 104), (255, 255, 0), (255, 215, 227), (210, 240, 119), \
+ # (209, 238, 245), (244, 200, 243), (233, 242, 216)]
+ att_colors = [(255, 255, 0), (244, 200, 243), (210, 243, 119), (209, 238, 255), (200, 208, 255), (250, 238, 215)]
+
+
+ overlay = img.copy()
+ # output = img.copy()
+ # Plots one bounding box on image img
+ tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
+
+ color = list(att_colors[color_index])
+ c1, c2 = (int(ref[0]), int(ref[1])), (int(vertex[0]), int(vertex[1]))
+ cv2.line(overlay, c1, c2, (alpha*float(color[0])/255,alpha*float(color[1])/255,alpha*float(color[2])/255) , thickness=tl, lineType=cv2.LINE_AA)
+ cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0, img)
+
+
+
+def cam2pixel(cam_coord, f, c):
+ x = cam_coord[:, 0] / (cam_coord[:, 2]) * f[0] + c[0]
+ y = cam_coord[:, 1] / (cam_coord[:, 2]) * f[1] + c[1]
+ z = cam_coord[:, 2]
+ img_coord = np.concatenate((x[:,None], y[:,None], z[:,None]),1)
+ return img_coord
+
+
+class Renderer(object):
+ """
+ Render mesh using OpenDR for visualization.
+ """
+
+ def __init__(self, width=800, height=600, near=0.5, far=1000, faces=None):
+ self.colors = {'hand': [.9, .9, .9], 'pink': [.9, .7, .7], 'light_blue': [0.65098039, 0.74117647, 0.85882353] }
+ self.width = width
+ self.height = height
+ self.faces = faces
+ self.renderer = ColoredRenderer()
+
+ def render(self, vertices, faces=None, img=None,
+ camera_t=np.zeros([3], dtype=np.float32),
+ camera_rot=np.zeros([3], dtype=np.float32),
+ camera_center=None,
+ use_bg=False,
+ bg_color=(0.0, 0.0, 0.0),
+ body_color=None,
+ focal_length=5000,
+ disp_text=False,
+ gt_keyp=None,
+ pred_keyp=None,
+ **kwargs):
+ if img is not None:
+ height, width = img.shape[:2]
+ else:
+ height, width = self.height, self.width
+
+ if faces is None:
+ faces = self.faces
+
+ if camera_center is None:
+ camera_center = np.array([width * 0.5,
+ height * 0.5])
+
+ self.renderer.camera = ProjectPoints(rt=camera_rot,
+ t=camera_t,
+ f=focal_length * np.ones(2),
+ c=camera_center,
+ k=np.zeros(5))
+ dist = np.abs(self.renderer.camera.t.r[2] -
+ np.mean(vertices, axis=0)[2])
+ far = dist + 20
+
+ self.renderer.frustum = {'near': 1.0, 'far': far,
+ 'width': width,
+ 'height': height}
+
+ if img is not None:
+ if use_bg:
+ self.renderer.background_image = img
+ else:
+ self.renderer.background_image = np.ones_like(
+ img) * np.array(bg_color)
+
+ if body_color is None:
+ color = self.colors['light_blue']
+ else:
+ color = self.colors[body_color]
+
+ if isinstance(self.renderer, TexturedRenderer):
+ color = [1.,1.,1.]
+
+ self.renderer.set(v=vertices, f=faces,
+ vc=color, bgcolor=np.ones(3))
+ albedo = self.renderer.vc
+ # Construct Back Light (on back right corner)
+ yrot = np.radians(120)
+
+ self.renderer.vc = LambertianPointLight(
+ f=self.renderer.f,
+ v=self.renderer.v,
+ num_verts=self.renderer.v.shape[0],
+ light_pos=rotateY(np.array([-200, -100, -100]), yrot),
+ vc=albedo,
+ light_color=np.array([1, 1, 1]))
+
+ # Construct Left Light
+ self.renderer.vc += LambertianPointLight(
+ f=self.renderer.f,
+ v=self.renderer.v,
+ num_verts=self.renderer.v.shape[0],
+ light_pos=rotateY(np.array([800, 10, 300]), yrot),
+ vc=albedo,
+ light_color=np.array([1, 1, 1]))
+
+ # Construct Right Light
+ self.renderer.vc += LambertianPointLight(
+ f=self.renderer.f,
+ v=self.renderer.v,
+ num_verts=self.renderer.v.shape[0],
+ light_pos=rotateY(np.array([-500, 500, 1000]), yrot),
+ vc=albedo,
+ light_color=np.array([.7, .7, .7]))
+
+ return self.renderer.r
+
+
+ def render_vertex_color(self, vertices, faces=None, img=None,
+ camera_t=np.zeros([3], dtype=np.float32),
+ camera_rot=np.zeros([3], dtype=np.float32),
+ camera_center=None,
+ use_bg=False,
+ bg_color=(0.0, 0.0, 0.0),
+ vertex_color=None,
+ focal_length=5000,
+ disp_text=False,
+ gt_keyp=None,
+ pred_keyp=None,
+ **kwargs):
+ if img is not None:
+ height, width = img.shape[:2]
+ else:
+ height, width = self.height, self.width
+
+ if faces is None:
+ faces = self.faces
+
+ if camera_center is None:
+ camera_center = np.array([width * 0.5,
+ height * 0.5])
+
+ self.renderer.camera = ProjectPoints(rt=camera_rot,
+ t=camera_t,
+ f=focal_length * np.ones(2),
+ c=camera_center,
+ k=np.zeros(5))
+ dist = np.abs(self.renderer.camera.t.r[2] -
+ np.mean(vertices, axis=0)[2])
+ far = dist + 20
+
+ self.renderer.frustum = {'near': 1.0, 'far': far,
+ 'width': width,
+ 'height': height}
+
+ if img is not None:
+ if use_bg:
+ self.renderer.background_image = img
+ else:
+ self.renderer.background_image = np.ones_like(
+ img) * np.array(bg_color)
+
+ if vertex_color is None:
+ vertex_color = self.colors['light_blue']
+
+
+ self.renderer.set(v=vertices, f=faces,
+ vc=vertex_color, bgcolor=np.ones(3))
+ albedo = self.renderer.vc
+ # Construct Back Light (on back right corner)
+ yrot = np.radians(120)
+
+ self.renderer.vc = LambertianPointLight(
+ f=self.renderer.f,
+ v=self.renderer.v,
+ num_verts=self.renderer.v.shape[0],
+ light_pos=rotateY(np.array([-200, -100, -100]), yrot),
+ vc=albedo,
+ light_color=np.array([1, 1, 1]))
+
+ # Construct Left Light
+ self.renderer.vc += LambertianPointLight(
+ f=self.renderer.f,
+ v=self.renderer.v,
+ num_verts=self.renderer.v.shape[0],
+ light_pos=rotateY(np.array([800, 10, 300]), yrot),
+ vc=albedo,
+ light_color=np.array([1, 1, 1]))
+
+ # Construct Right Light
+ self.renderer.vc += LambertianPointLight(
+ f=self.renderer.f,
+ v=self.renderer.v,
+ num_verts=self.renderer.v.shape[0],
+ light_pos=rotateY(np.array([-500, 500, 1000]), yrot),
+ vc=albedo,
+ light_color=np.array([.7, .7, .7]))
+
+ return self.renderer.r
\ No newline at end of file
diff --git a/src/custom_mesh_graphormer/utils/tsv_file.py b/src/custom_mesh_graphormer/utils/tsv_file.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e79a9ec55d7ee9a2ce622d7a4ae5af59f7d7c17
--- /dev/null
+++ b/src/custom_mesh_graphormer/utils/tsv_file.py
@@ -0,0 +1,162 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+Definition of TSV class
+"""
+
+
+import logging
+import os
+import os.path as op
+
+
+def generate_lineidx(filein, idxout):
+ idxout_tmp = idxout + '.tmp'
+ with open(filein, 'r') as tsvin, open(idxout_tmp,'w') as tsvout:
+ fsize = os.fstat(tsvin.fileno()).st_size
+ fpos = 0
+ while fpos!=fsize:
+ tsvout.write(str(fpos)+"\n")
+ tsvin.readline()
+ fpos = tsvin.tell()
+ os.rename(idxout_tmp, idxout)
+
+
+def read_to_character(fp, c):
+ result = []
+ while True:
+ s = fp.read(32)
+ assert s != ''
+ if c in s:
+ result.append(s[: s.index(c)])
+ break
+ else:
+ result.append(s)
+ return ''.join(result)
+
+
+class TSVFile(object):
+ def __init__(self, tsv_file, generate_lineidx=False):
+ self.tsv_file = tsv_file
+ self.lineidx = op.splitext(tsv_file)[0] + '.lineidx'
+ self._fp = None
+ self._lineidx = None
+ # the process always keeps the process which opens the file.
+ # If the pid is not equal to the currrent pid, we will re-open the file.
+ self.pid = None
+ # generate lineidx if not exist
+ if not op.isfile(self.lineidx) and generate_lineidx:
+ generate_lineidx(self.tsv_file, self.lineidx)
+
+ def __del__(self):
+ if self._fp:
+ self._fp.close()
+
+ def __str__(self):
+ return "TSVFile(tsv_file='{}')".format(self.tsv_file)
+
+ def __repr__(self):
+ return str(self)
+
+ def num_rows(self):
+ self._ensure_lineidx_loaded()
+ return len(self._lineidx)
+
+ def seek(self, idx):
+ self._ensure_tsv_opened()
+ self._ensure_lineidx_loaded()
+ try:
+ pos = self._lineidx[idx]
+ except:
+ logging.info('{}-{}'.format(self.tsv_file, idx))
+ raise
+ self._fp.seek(pos)
+ return [s.strip() for s in self._fp.readline().split('\t')]
+
+ def seek_first_column(self, idx):
+ self._ensure_tsv_opened()
+ self._ensure_lineidx_loaded()
+ pos = self._lineidx[idx]
+ self._fp.seek(pos)
+ return read_to_character(self._fp, '\t')
+
+ def get_key(self, idx):
+ return self.seek_first_column(idx)
+
+ def __getitem__(self, index):
+ return self.seek(index)
+
+ def __len__(self):
+ return self.num_rows()
+
+ def _ensure_lineidx_loaded(self):
+ if self._lineidx is None:
+ logging.info('loading lineidx: {}'.format(self.lineidx))
+ with open(self.lineidx, 'r') as fp:
+ self._lineidx = [int(i.strip()) for i in fp.readlines()]
+
+ def _ensure_tsv_opened(self):
+ if self._fp is None:
+ self._fp = open(self.tsv_file, 'r')
+ self.pid = os.getpid()
+
+ if self.pid != os.getpid():
+ logging.info('re-open {} because the process id changed'.format(self.tsv_file))
+ self._fp = open(self.tsv_file, 'r')
+ self.pid = os.getpid()
+
+
+class CompositeTSVFile():
+ def __init__(self, file_list, seq_file, root='.'):
+ if isinstance(file_list, str):
+ self.file_list = load_list_file(file_list)
+ else:
+ assert isinstance(file_list, list)
+ self.file_list = file_list
+
+ self.seq_file = seq_file
+ self.root = root
+ self.initialized = False
+ self.initialize()
+
+ def get_key(self, index):
+ idx_source, idx_row = self.seq[index]
+ k = self.tsvs[idx_source].get_key(idx_row)
+ return '_'.join([self.file_list[idx_source], k])
+
+ def num_rows(self):
+ return len(self.seq)
+
+ def __getitem__(self, index):
+ idx_source, idx_row = self.seq[index]
+ return self.tsvs[idx_source].seek(idx_row)
+
+ def __len__(self):
+ return len(self.seq)
+
+ def initialize(self):
+ '''
+ this function has to be called in init function if cache_policy is
+ enabled. Thus, let's always call it in init funciton to make it simple.
+ '''
+ if self.initialized:
+ return
+ self.seq = []
+ with open(self.seq_file, 'r') as fp:
+ for line in fp:
+ parts = line.strip().split('\t')
+ self.seq.append([int(parts[0]), int(parts[1])])
+ self.tsvs = [TSVFile(op.join(self.root, f)) for f in self.file_list]
+ self.initialized = True
+
+
+def load_list_file(fname):
+ with open(fname, 'r') as fp:
+ lines = fp.readlines()
+ result = [line.strip() for line in lines]
+ if len(result) > 0 and result[-1] == '':
+ result = result[:-1]
+ return result
+
+
diff --git a/src/custom_mesh_graphormer/utils/tsv_file_ops.py b/src/custom_mesh_graphormer/utils/tsv_file_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..70f2798b931fd937d100b2093228013eb6176113
--- /dev/null
+++ b/src/custom_mesh_graphormer/utils/tsv_file_ops.py
@@ -0,0 +1,116 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+Basic operations for TSV files
+"""
+
+
+import os
+import os.path as op
+import json
+import numpy as np
+import base64
+import cv2
+from tqdm import tqdm
+import yaml
+from custom_mesh_graphormer.utils.miscellaneous import mkdir
+from custom_mesh_graphormer.utils.tsv_file import TSVFile
+
+
+def img_from_base64(imagestring):
+ try:
+ jpgbytestring = base64.b64decode(imagestring)
+ nparr = np.frombuffer(jpgbytestring, np.uint8)
+ r = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+ return r
+ except ValueError:
+ return None
+
+def load_linelist_file(linelist_file):
+ if linelist_file is not None:
+ line_list = []
+ with open(linelist_file, 'r') as fp:
+ for i in fp:
+ line_list.append(int(i.strip()))
+ return line_list
+
+def tsv_writer(values, tsv_file, sep='\t'):
+ mkdir(op.dirname(tsv_file))
+ lineidx_file = op.splitext(tsv_file)[0] + '.lineidx'
+ idx = 0
+ tsv_file_tmp = tsv_file + '.tmp'
+ lineidx_file_tmp = lineidx_file + '.tmp'
+ with open(tsv_file_tmp, 'w') as fp, open(lineidx_file_tmp, 'w') as fpidx:
+ assert values is not None
+ for value in values:
+ assert value is not None
+ value = [v if type(v)!=bytes else v.decode('utf-8') for v in value]
+ v = '{0}\n'.format(sep.join(map(str, value)))
+ fp.write(v)
+ fpidx.write(str(idx) + '\n')
+ idx = idx + len(v)
+ os.rename(tsv_file_tmp, tsv_file)
+ os.rename(lineidx_file_tmp, lineidx_file)
+
+def tsv_reader(tsv_file, sep='\t'):
+ with open(tsv_file, 'r') as fp:
+ for i, line in enumerate(fp):
+ yield [x.strip() for x in line.split(sep)]
+
+def config_save_file(tsv_file, save_file=None, append_str='.new.tsv'):
+ if save_file is not None:
+ return save_file
+ return op.splitext(tsv_file)[0] + append_str
+
+def get_line_list(linelist_file=None, num_rows=None):
+ if linelist_file is not None:
+ return load_linelist_file(linelist_file)
+
+ if num_rows is not None:
+ return [i for i in range(num_rows)]
+
+def generate_hw_file(img_file, save_file=None):
+ rows = tsv_reader(img_file)
+ def gen_rows():
+ for i, row in tqdm(enumerate(rows)):
+ row1 = [row[0]]
+ img = img_from_base64(row[-1])
+ height = img.shape[0]
+ width = img.shape[1]
+ row1.append(json.dumps([{"height":height, "width": width}]))
+ yield row1
+
+ save_file = config_save_file(img_file, save_file, '.hw.tsv')
+ tsv_writer(gen_rows(), save_file)
+
+def generate_linelist_file(label_file, save_file=None, ignore_attrs=()):
+ # generate a list of image that has labels
+ # images with only ignore labels are not selected.
+ line_list = []
+ rows = tsv_reader(label_file)
+ for i, row in tqdm(enumerate(rows)):
+ labels = json.loads(row[1])
+ if labels:
+ if ignore_attrs and all([any([lab[attr] for attr in ignore_attrs if attr in lab]) \
+ for lab in labels]):
+ continue
+ line_list.append([i])
+
+ save_file = config_save_file(label_file, save_file, '.linelist.tsv')
+ tsv_writer(line_list, save_file)
+
+def load_from_yaml_file(yaml_file):
+ with open(yaml_file, 'r') as fp:
+ return yaml.load(fp, Loader=yaml.CLoader)
+
+def find_file_path_in_yaml(fname, root):
+ if fname is not None:
+ if op.isfile(fname):
+ return fname
+ elif op.isfile(op.join(root, fname)):
+ return op.join(root, fname)
+ else:
+ raise FileNotFoundError(
+ errno.ENOENT, os.strerror(errno.ENOENT), op.join(root, fname)
+ )
diff --git a/src/custom_mmpkg/__init__.py b/src/custom_mmpkg/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..33e7a7f594ef441479257c788e4c0d6e08657fc8
--- /dev/null
+++ b/src/custom_mmpkg/__init__.py
@@ -0,0 +1 @@
+#Dummy file ensuring this package will be recognized
\ No newline at end of file
diff --git a/src/custom_mmpkg/custom_mmcv/__init__.py b/src/custom_mmpkg/custom_mmcv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..210a2989138380559f23045b568d0fbbeb918c03
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# flake8: noqa
+from .arraymisc import *
+from .fileio import *
+from .image import *
+from .utils import *
+from .version import *
+from .video import *
+from .visualization import *
+
+# The following modules are not imported to this level, so mmcv may be used
+# without PyTorch.
+# - runner
+# - parallel
+# - op
diff --git a/src/custom_mmpkg/custom_mmcv/arraymisc/__init__.py b/src/custom_mmpkg/custom_mmcv/arraymisc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b4700d6139ae3d604ff6e542468cce4200c020c
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/arraymisc/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .quantization import dequantize, quantize
+
+__all__ = ['quantize', 'dequantize']
diff --git a/src/custom_mmpkg/custom_mmcv/arraymisc/quantization.py b/src/custom_mmpkg/custom_mmcv/arraymisc/quantization.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e47a3545780cf071a1ef8195efb0b7b662c8186
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/arraymisc/quantization.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+
+
+def quantize(arr, min_val, max_val, levels, dtype=np.int64):
+ """Quantize an array of (-inf, inf) to [0, levels-1].
+
+ Args:
+ arr (ndarray): Input array.
+ min_val (scalar): Minimum value to be clipped.
+ max_val (scalar): Maximum value to be clipped.
+ levels (int): Quantization levels.
+ dtype (np.type): The type of the quantized array.
+
+ Returns:
+ tuple: Quantized array.
+ """
+ if not (isinstance(levels, int) and levels > 1):
+ raise ValueError(
+ f'levels must be a positive integer, but got {levels}')
+ if min_val >= max_val:
+ raise ValueError(
+ f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+ arr = np.clip(arr, min_val, max_val) - min_val
+ quantized_arr = np.minimum(
+ np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
+
+ return quantized_arr
+
+
+def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
+ """Dequantize an array.
+
+ Args:
+ arr (ndarray): Input array.
+ min_val (scalar): Minimum value to be clipped.
+ max_val (scalar): Maximum value to be clipped.
+ levels (int): Quantization levels.
+ dtype (np.type): The type of the dequantized array.
+
+ Returns:
+ tuple: Dequantized array.
+ """
+ if not (isinstance(levels, int) and levels > 1):
+ raise ValueError(
+ f'levels must be a positive integer, but got {levels}')
+ if min_val >= max_val:
+ raise ValueError(
+ f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+ dequantized_arr = (arr + 0.5).astype(dtype) * (max_val -
+ min_val) / levels + min_val
+
+ return dequantized_arr
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/__init__.py b/src/custom_mmpkg/custom_mmcv/cnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7246c897430f0cc7ce12719ad8608824fc734446
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/__init__.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .alexnet import AlexNet
+# yapf: disable
+from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
+ PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS,
+ ContextBlock, Conv2d, Conv3d, ConvAWS2d, ConvModule,
+ ConvTranspose2d, ConvTranspose3d, ConvWS2d,
+ DepthwiseSeparableConvModule, GeneralizedAttention,
+ HSigmoid, HSwish, Linear, MaxPool2d, MaxPool3d,
+ NonLocal1d, NonLocal2d, NonLocal3d, Scale, Swish,
+ build_activation_layer, build_conv_layer,
+ build_norm_layer, build_padding_layer, build_plugin_layer,
+ build_upsample_layer, conv_ws_2d, is_norm)
+from .builder import MODELS, build_model_from_cfg
+# yapf: enable
+from .resnet import ResNet, make_res_layer
+from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
+ NormalInit, PretrainedInit, TruncNormalInit, UniformInit,
+ XavierInit, bias_init_with_prob, caffe2_xavier_init,
+ constant_init, fuse_conv_bn, get_model_complexity_info,
+ initialize, kaiming_init, normal_init, trunc_normal_init,
+ uniform_init, xavier_init)
+from .vgg import VGG, make_vgg_layer
+
+__all__ = [
+ 'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
+ 'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init',
+ 'uniform_init', 'kaiming_init', 'caffe2_xavier_init',
+ 'bias_init_with_prob', 'ConvModule', 'build_activation_layer',
+ 'build_conv_layer', 'build_norm_layer', 'build_padding_layer',
+ 'build_upsample_layer', 'build_plugin_layer', 'is_norm', 'NonLocal1d',
+ 'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'HSigmoid', 'Swish', 'HSwish',
+ 'GeneralizedAttention', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS',
+ 'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale',
+ 'get_model_complexity_info', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d',
+ 'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'Linear', 'Conv2d',
+ 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d',
+ 'initialize', 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+ 'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+ 'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg'
+]
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/alexnet.py b/src/custom_mmpkg/custom_mmcv/cnn/alexnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..89e36b8c7851f895d9ae7f07149f0e707456aab0
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/alexnet.py
@@ -0,0 +1,61 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+
+import torch.nn as nn
+
+
+class AlexNet(nn.Module):
+ """AlexNet backbone.
+
+ Args:
+ num_classes (int): number of classes for classification.
+ """
+
+ def __init__(self, num_classes=-1):
+ super(AlexNet, self).__init__()
+ self.num_classes = num_classes
+ self.features = nn.Sequential(
+ nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ nn.Conv2d(64, 192, kernel_size=5, padding=2),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ nn.Conv2d(192, 384, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(384, 256, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(256, 256, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ )
+ if self.num_classes > 0:
+ self.classifier = nn.Sequential(
+ nn.Dropout(),
+ nn.Linear(256 * 6 * 6, 4096),
+ nn.ReLU(inplace=True),
+ nn.Dropout(),
+ nn.Linear(4096, 4096),
+ nn.ReLU(inplace=True),
+ nn.Linear(4096, num_classes),
+ )
+
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+ logger = logging.getLogger()
+ from ..runner import load_checkpoint
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ # use default initializer
+ pass
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x):
+
+ x = self.features(x)
+ if self.num_classes > 0:
+ x = x.view(x.size(0), 256 * 6 * 6)
+ x = self.classifier(x)
+
+ return x
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/bricks/__init__.py b/src/custom_mmpkg/custom_mmcv/cnn/bricks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f33124ed23fc6f27119a37bcb5ab004d3572be0
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/bricks/__init__.py
@@ -0,0 +1,35 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .activation import build_activation_layer
+from .context_block import ContextBlock
+from .conv import build_conv_layer
+from .conv2d_adaptive_padding import Conv2dAdaptivePadding
+from .conv_module import ConvModule
+from .conv_ws import ConvAWS2d, ConvWS2d, conv_ws_2d
+from .depthwise_separable_conv_module import DepthwiseSeparableConvModule
+from .drop import Dropout, DropPath
+from .generalized_attention import GeneralizedAttention
+from .hsigmoid import HSigmoid
+from .hswish import HSwish
+from .non_local import NonLocal1d, NonLocal2d, NonLocal3d
+from .norm import build_norm_layer, is_norm
+from .padding import build_padding_layer
+from .plugin import build_plugin_layer
+from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
+ PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS)
+from .scale import Scale
+from .swish import Swish
+from .upsample import build_upsample_layer
+from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
+ Linear, MaxPool2d, MaxPool3d)
+
+__all__ = [
+ 'ConvModule', 'build_activation_layer', 'build_conv_layer',
+ 'build_norm_layer', 'build_padding_layer', 'build_upsample_layer',
+ 'build_plugin_layer', 'is_norm', 'HSigmoid', 'HSwish', 'NonLocal1d',
+ 'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'GeneralizedAttention',
+ 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
+ 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
+ 'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
+ 'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d',
+ 'ConvTranspose3d', 'MaxPool3d', 'Conv3d', 'Dropout', 'DropPath'
+]
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/bricks/activation.py b/src/custom_mmpkg/custom_mmcv/cnn/bricks/activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..0881d7201de63ea47c9e585eead35f5c12c1881f
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/bricks/activation.py
@@ -0,0 +1,92 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from custom_mmpkg.custom_mmcv.utils import TORCH_VERSION, build_from_cfg, digit_version
+from .registry import ACTIVATION_LAYERS
+
+for module in [
+ nn.ReLU, nn.LeakyReLU, nn.PReLU, nn.RReLU, nn.ReLU6, nn.ELU,
+ nn.Sigmoid, nn.Tanh
+]:
+ ACTIVATION_LAYERS.register_module(module=module)
+
+
+@ACTIVATION_LAYERS.register_module(name='Clip')
+@ACTIVATION_LAYERS.register_module()
+class Clamp(nn.Module):
+ """Clamp activation layer.
+
+ This activation function is to clamp the feature map value within
+ :math:`[min, max]`. More details can be found in ``torch.clamp()``.
+
+ Args:
+ min (Number | optional): Lower-bound of the range to be clamped to.
+ Default to -1.
+ max (Number | optional): Upper-bound of the range to be clamped to.
+ Default to 1.
+ """
+
+ def __init__(self, min=-1., max=1.):
+ super(Clamp, self).__init__()
+ self.min = min
+ self.max = max
+
+ def forward(self, x):
+ """Forward function.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: Clamped tensor.
+ """
+ return torch.clamp(x, min=self.min, max=self.max)
+
+
+class GELU(nn.Module):
+ r"""Applies the Gaussian Error Linear Units function:
+
+ .. math::
+ \text{GELU}(x) = x * \Phi(x)
+ where :math:`\Phi(x)` is the Cumulative Distribution Function for
+ Gaussian Distribution.
+
+ Shape:
+ - Input: :math:`(N, *)` where `*` means, any number of additional
+ dimensions
+ - Output: :math:`(N, *)`, same shape as the input
+
+ .. image:: scripts/activation_images/GELU.png
+
+ Examples::
+
+ >>> m = nn.GELU()
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+ """
+
+ def forward(self, input):
+ return F.gelu(input)
+
+
+if (TORCH_VERSION == 'parrots'
+ or digit_version(TORCH_VERSION) < digit_version('1.4')):
+ ACTIVATION_LAYERS.register_module(module=GELU)
+else:
+ ACTIVATION_LAYERS.register_module(module=nn.GELU)
+
+
+def build_activation_layer(cfg):
+ """Build activation layer.
+
+ Args:
+ cfg (dict): The activation layer config, which should contain:
+ - type (str): Layer type.
+ - layer args: Args needed to instantiate an activation layer.
+
+ Returns:
+ nn.Module: Created activation layer.
+ """
+ return build_from_cfg(cfg, ACTIVATION_LAYERS)
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/bricks/context_block.py b/src/custom_mmpkg/custom_mmcv/cnn/bricks/context_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..d60fdb904c749ce3b251510dff3cc63cea70d42e
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/bricks/context_block.py
@@ -0,0 +1,125 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn
+
+from ..utils import constant_init, kaiming_init
+from .registry import PLUGIN_LAYERS
+
+
+def last_zero_init(m):
+ if isinstance(m, nn.Sequential):
+ constant_init(m[-1], val=0)
+ else:
+ constant_init(m, val=0)
+
+
+@PLUGIN_LAYERS.register_module()
+class ContextBlock(nn.Module):
+ """ContextBlock module in GCNet.
+
+ See 'GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond'
+ (https://arxiv.org/abs/1904.11492) for details.
+
+ Args:
+ in_channels (int): Channels of the input feature map.
+ ratio (float): Ratio of channels of transform bottleneck
+ pooling_type (str): Pooling method for context modeling.
+ Options are 'att' and 'avg', stand for attention pooling and
+ average pooling respectively. Default: 'att'.
+ fusion_types (Sequence[str]): Fusion method for feature fusion,
+ Options are 'channels_add', 'channel_mul', stand for channelwise
+ addition and multiplication respectively. Default: ('channel_add',)
+ """
+
+ _abbr_ = 'context_block'
+
+ def __init__(self,
+ in_channels,
+ ratio,
+ pooling_type='att',
+ fusion_types=('channel_add', )):
+ super(ContextBlock, self).__init__()
+ assert pooling_type in ['avg', 'att']
+ assert isinstance(fusion_types, (list, tuple))
+ valid_fusion_types = ['channel_add', 'channel_mul']
+ assert all([f in valid_fusion_types for f in fusion_types])
+ assert len(fusion_types) > 0, 'at least one fusion should be used'
+ self.in_channels = in_channels
+ self.ratio = ratio
+ self.planes = int(in_channels * ratio)
+ self.pooling_type = pooling_type
+ self.fusion_types = fusion_types
+ if pooling_type == 'att':
+ self.conv_mask = nn.Conv2d(in_channels, 1, kernel_size=1)
+ self.softmax = nn.Softmax(dim=2)
+ else:
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ if 'channel_add' in fusion_types:
+ self.channel_add_conv = nn.Sequential(
+ nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
+ nn.LayerNorm([self.planes, 1, 1]),
+ nn.ReLU(inplace=True), # yapf: disable
+ nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
+ else:
+ self.channel_add_conv = None
+ if 'channel_mul' in fusion_types:
+ self.channel_mul_conv = nn.Sequential(
+ nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
+ nn.LayerNorm([self.planes, 1, 1]),
+ nn.ReLU(inplace=True), # yapf: disable
+ nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
+ else:
+ self.channel_mul_conv = None
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ if self.pooling_type == 'att':
+ kaiming_init(self.conv_mask, mode='fan_in')
+ self.conv_mask.inited = True
+
+ if self.channel_add_conv is not None:
+ last_zero_init(self.channel_add_conv)
+ if self.channel_mul_conv is not None:
+ last_zero_init(self.channel_mul_conv)
+
+ def spatial_pool(self, x):
+ batch, channel, height, width = x.size()
+ if self.pooling_type == 'att':
+ input_x = x
+ # [N, C, H * W]
+ input_x = input_x.view(batch, channel, height * width)
+ # [N, 1, C, H * W]
+ input_x = input_x.unsqueeze(1)
+ # [N, 1, H, W]
+ context_mask = self.conv_mask(x)
+ # [N, 1, H * W]
+ context_mask = context_mask.view(batch, 1, height * width)
+ # [N, 1, H * W]
+ context_mask = self.softmax(context_mask)
+ # [N, 1, H * W, 1]
+ context_mask = context_mask.unsqueeze(-1)
+ # [N, 1, C, 1]
+ context = torch.matmul(input_x, context_mask)
+ # [N, C, 1, 1]
+ context = context.view(batch, channel, 1, 1)
+ else:
+ # [N, C, 1, 1]
+ context = self.avg_pool(x)
+
+ return context
+
+ def forward(self, x):
+ # [N, C, 1, 1]
+ context = self.spatial_pool(x)
+
+ out = x
+ if self.channel_mul_conv is not None:
+ # [N, C, 1, 1]
+ channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
+ out = out * channel_mul_term
+ if self.channel_add_conv is not None:
+ # [N, C, 1, 1]
+ channel_add_term = self.channel_add_conv(context)
+ out = out + channel_add_term
+
+ return out
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/bricks/conv.py b/src/custom_mmpkg/custom_mmcv/cnn/bricks/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf54491997a48ac3e7fadc4183ab7bf3e831024c
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/bricks/conv.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch import nn
+
+from .registry import CONV_LAYERS
+
+CONV_LAYERS.register_module('Conv1d', module=nn.Conv1d)
+CONV_LAYERS.register_module('Conv2d', module=nn.Conv2d)
+CONV_LAYERS.register_module('Conv3d', module=nn.Conv3d)
+CONV_LAYERS.register_module('Conv', module=nn.Conv2d)
+
+
+def build_conv_layer(cfg, *args, **kwargs):
+ """Build convolution layer.
+
+ Args:
+ cfg (None or dict): The conv layer config, which should contain:
+ - type (str): Layer type.
+ - layer args: Args needed to instantiate an conv layer.
+ args (argument list): Arguments passed to the `__init__`
+ method of the corresponding conv layer.
+ kwargs (keyword arguments): Keyword arguments passed to the `__init__`
+ method of the corresponding conv layer.
+
+ Returns:
+ nn.Module: Created conv layer.
+ """
+ if cfg is None:
+ cfg_ = dict(type='Conv2d')
+ else:
+ if not isinstance(cfg, dict):
+ raise TypeError('cfg must be a dict')
+ if 'type' not in cfg:
+ raise KeyError('the cfg dict must contain the key "type"')
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop('type')
+ if layer_type not in CONV_LAYERS:
+ raise KeyError(f'Unrecognized norm type {layer_type}')
+ else:
+ conv_layer = CONV_LAYERS.get(layer_type)
+
+ layer = conv_layer(*args, **kwargs, **cfg_)
+
+ return layer
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/bricks/conv2d_adaptive_padding.py b/src/custom_mmpkg/custom_mmcv/cnn/bricks/conv2d_adaptive_padding.py
new file mode 100644
index 0000000000000000000000000000000000000000..b45e758ac6cf8dfb0382d072fe09125bc7e9b888
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/bricks/conv2d_adaptive_padding.py
@@ -0,0 +1,62 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+from torch import nn
+from torch.nn import functional as F
+
+from .registry import CONV_LAYERS
+
+
+@CONV_LAYERS.register_module()
+class Conv2dAdaptivePadding(nn.Conv2d):
+ """Implementation of 2D convolution in tensorflow with `padding` as "same",
+ which applies padding to input (if needed) so that input image gets fully
+ covered by filter and stride you specified. For stride 1, this will ensure
+ that output image size is same as input. For stride of 2, output dimensions
+ will be half, for example.
+
+ Args:
+ in_channels (int): Number of channels in the input image
+ out_channels (int): Number of channels produced by the convolution
+ kernel_size (int or tuple): Size of the convolving kernel
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
+ padding (int or tuple, optional): Zero-padding added to both sides of
+ the input. Default: 0
+ dilation (int or tuple, optional): Spacing between kernel elements.
+ Default: 1
+ groups (int, optional): Number of blocked connections from input
+ channels to output channels. Default: 1
+ bias (bool, optional): If ``True``, adds a learnable bias to the
+ output. Default: ``True``
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=True):
+ super().__init__(in_channels, out_channels, kernel_size, stride, 0,
+ dilation, groups, bias)
+
+ def forward(self, x):
+ img_h, img_w = x.size()[-2:]
+ kernel_h, kernel_w = self.weight.size()[-2:]
+ stride_h, stride_w = self.stride
+ output_h = math.ceil(img_h / stride_h)
+ output_w = math.ceil(img_w / stride_w)
+ pad_h = (
+ max((output_h - 1) * self.stride[0] +
+ (kernel_h - 1) * self.dilation[0] + 1 - img_h, 0))
+ pad_w = (
+ max((output_w - 1) * self.stride[1] +
+ (kernel_w - 1) * self.dilation[1] + 1 - img_w, 0))
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, [
+ pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
+ ])
+ return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
+ self.dilation, self.groups)
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/bricks/conv_module.py b/src/custom_mmpkg/custom_mmcv/cnn/bricks/conv_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9b82b6b35939be7031462d3febb6561e42854ea
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/bricks/conv_module.py
@@ -0,0 +1,206 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch.nn as nn
+
+from custom_mmpkg.custom_mmcv.utils import _BatchNorm, _InstanceNorm
+from ..utils import constant_init, kaiming_init
+from .activation import build_activation_layer
+from .conv import build_conv_layer
+from .norm import build_norm_layer
+from .padding import build_padding_layer
+from .registry import PLUGIN_LAYERS
+
+
+@PLUGIN_LAYERS.register_module()
+class ConvModule(nn.Module):
+ """A conv block that bundles conv/norm/activation layers.
+
+ This block simplifies the usage of convolution layers, which are commonly
+ used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
+ It is based upon three build methods: `build_conv_layer()`,
+ `build_norm_layer()` and `build_activation_layer()`.
+
+ Besides, we add some additional features in this module.
+ 1. Automatically set `bias` of the conv layer.
+ 2. Spectral norm is supported.
+ 3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
+ supports zero and circular padding, and we add "reflect" padding mode.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ Same as that in ``nn._ConvNd``.
+ out_channels (int): Number of channels produced by the convolution.
+ Same as that in ``nn._ConvNd``.
+ kernel_size (int | tuple[int]): Size of the convolving kernel.
+ Same as that in ``nn._ConvNd``.
+ stride (int | tuple[int]): Stride of the convolution.
+ Same as that in ``nn._ConvNd``.
+ padding (int | tuple[int]): Zero-padding added to both sides of
+ the input. Same as that in ``nn._ConvNd``.
+ dilation (int | tuple[int]): Spacing between kernel elements.
+ Same as that in ``nn._ConvNd``.
+ groups (int): Number of blocked connections from input channels to
+ output channels. Same as that in ``nn._ConvNd``.
+ bias (bool | str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise
+ False. Default: "auto".
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
+ which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU').
+ inplace (bool): Whether to use inplace mode for activation.
+ Default: True.
+ with_spectral_norm (bool): Whether use spectral norm in conv module.
+ Default: False.
+ padding_mode (str): If the `padding_mode` has not been supported by
+ current `Conv2d` in PyTorch, we will use our own padding layer
+ instead. Currently, we support ['zeros', 'circular'] with official
+ implementation and ['reflect'] with our own implementation.
+ Default: 'zeros'.
+ order (tuple[str]): The order of conv/norm/activation layers. It is a
+ sequence of "conv", "norm" and "act". Common examples are
+ ("conv", "norm", "act") and ("act", "conv", "norm").
+ Default: ('conv', 'norm', 'act').
+ """
+
+ _abbr_ = 'conv_block'
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias='auto',
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU'),
+ inplace=True,
+ with_spectral_norm=False,
+ padding_mode='zeros',
+ order=('conv', 'norm', 'act')):
+ super(ConvModule, self).__init__()
+ assert conv_cfg is None or isinstance(conv_cfg, dict)
+ assert norm_cfg is None or isinstance(norm_cfg, dict)
+ assert act_cfg is None or isinstance(act_cfg, dict)
+ official_padding_mode = ['zeros', 'circular']
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.inplace = inplace
+ self.with_spectral_norm = with_spectral_norm
+ self.with_explicit_padding = padding_mode not in official_padding_mode
+ self.order = order
+ assert isinstance(self.order, tuple) and len(self.order) == 3
+ assert set(order) == set(['conv', 'norm', 'act'])
+
+ self.with_norm = norm_cfg is not None
+ self.with_activation = act_cfg is not None
+ # if the conv layer is before a norm layer, bias is unnecessary.
+ if bias == 'auto':
+ bias = not self.with_norm
+ self.with_bias = bias
+
+ if self.with_explicit_padding:
+ pad_cfg = dict(type=padding_mode)
+ self.padding_layer = build_padding_layer(pad_cfg, padding)
+
+ # reset padding to 0 for conv module
+ conv_padding = 0 if self.with_explicit_padding else padding
+ # build convolution layer
+ self.conv = build_conv_layer(
+ conv_cfg,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=conv_padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias)
+ # export the attributes of self.conv to a higher level for convenience
+ self.in_channels = self.conv.in_channels
+ self.out_channels = self.conv.out_channels
+ self.kernel_size = self.conv.kernel_size
+ self.stride = self.conv.stride
+ self.padding = padding
+ self.dilation = self.conv.dilation
+ self.transposed = self.conv.transposed
+ self.output_padding = self.conv.output_padding
+ self.groups = self.conv.groups
+
+ if self.with_spectral_norm:
+ self.conv = nn.utils.spectral_norm(self.conv)
+
+ # build normalization layers
+ if self.with_norm:
+ # norm layer is after conv layer
+ if order.index('norm') > order.index('conv'):
+ norm_channels = out_channels
+ else:
+ norm_channels = in_channels
+ self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
+ self.add_module(self.norm_name, norm)
+ if self.with_bias:
+ if isinstance(norm, (_BatchNorm, _InstanceNorm)):
+ warnings.warn(
+ 'Unnecessary conv bias before batch/instance norm')
+ else:
+ self.norm_name = None
+
+ # build activation layer
+ if self.with_activation:
+ act_cfg_ = act_cfg.copy()
+ # nn.Tanh has no 'inplace' argument
+ if act_cfg_['type'] not in [
+ 'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish'
+ ]:
+ act_cfg_.setdefault('inplace', inplace)
+ self.activate = build_activation_layer(act_cfg_)
+
+ # Use msra init by default
+ self.init_weights()
+
+ @property
+ def norm(self):
+ if self.norm_name:
+ return getattr(self, self.norm_name)
+ else:
+ return None
+
+ def init_weights(self):
+ # 1. It is mainly for customized conv layers with their own
+ # initialization manners by calling their own ``init_weights()``,
+ # and we do not want ConvModule to override the initialization.
+ # 2. For customized conv layers without their own initialization
+ # manners (that is, they don't have their own ``init_weights()``)
+ # and PyTorch's conv layers, they will be initialized by
+ # this method with default ``kaiming_init``.
+ # Note: For PyTorch's conv layers, they will be overwritten by our
+ # initialization implementation using default ``kaiming_init``.
+ if not hasattr(self.conv, 'init_weights'):
+ if self.with_activation and self.act_cfg['type'] == 'LeakyReLU':
+ nonlinearity = 'leaky_relu'
+ a = self.act_cfg.get('negative_slope', 0.01)
+ else:
+ nonlinearity = 'relu'
+ a = 0
+ kaiming_init(self.conv, a=a, nonlinearity=nonlinearity)
+ if self.with_norm:
+ constant_init(self.norm, 1, bias=0)
+
+ def forward(self, x, activate=True, norm=True):
+ for layer in self.order:
+ if layer == 'conv':
+ if self.with_explicit_padding:
+ x = self.padding_layer(x)
+ x = self.conv(x)
+ elif layer == 'norm' and norm and self.with_norm:
+ x = self.norm(x)
+ elif layer == 'act' and activate and self.with_activation:
+ x = self.activate(x)
+ return x
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/bricks/conv_ws.py b/src/custom_mmpkg/custom_mmcv/cnn/bricks/conv_ws.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3941e27874993418b3b5708d5a7485f175ff9c8
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/bricks/conv_ws.py
@@ -0,0 +1,148 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .registry import CONV_LAYERS
+
+
+def conv_ws_2d(input,
+ weight,
+ bias=None,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ eps=1e-5):
+ c_in = weight.size(0)
+ weight_flat = weight.view(c_in, -1)
+ mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
+ std = weight_flat.std(dim=1, keepdim=True).view(c_in, 1, 1, 1)
+ weight = (weight - mean) / (std + eps)
+ return F.conv2d(input, weight, bias, stride, padding, dilation, groups)
+
+
+@CONV_LAYERS.register_module('ConvWS')
+class ConvWS2d(nn.Conv2d):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=True,
+ eps=1e-5):
+ super(ConvWS2d, self).__init__(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias)
+ self.eps = eps
+
+ def forward(self, x):
+ return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding,
+ self.dilation, self.groups, self.eps)
+
+
+@CONV_LAYERS.register_module(name='ConvAWS')
+class ConvAWS2d(nn.Conv2d):
+ """AWS (Adaptive Weight Standardization)
+
+ This is a variant of Weight Standardization
+ (https://arxiv.org/pdf/1903.10520.pdf)
+ It is used in DetectoRS to avoid NaN
+ (https://arxiv.org/pdf/2006.02334.pdf)
+
+ Args:
+ in_channels (int): Number of channels in the input image
+ out_channels (int): Number of channels produced by the convolution
+ kernel_size (int or tuple): Size of the conv kernel
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
+ padding (int or tuple, optional): Zero-padding added to both sides of
+ the input. Default: 0
+ dilation (int or tuple, optional): Spacing between kernel elements.
+ Default: 1
+ groups (int, optional): Number of blocked connections from input
+ channels to output channels. Default: 1
+ bias (bool, optional): If set True, adds a learnable bias to the
+ output. Default: True
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=True):
+ super().__init__(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias)
+ self.register_buffer('weight_gamma',
+ torch.ones(self.out_channels, 1, 1, 1))
+ self.register_buffer('weight_beta',
+ torch.zeros(self.out_channels, 1, 1, 1))
+
+ def _get_weight(self, weight):
+ weight_flat = weight.view(weight.size(0), -1)
+ mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
+ std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
+ weight = (weight - mean) / std
+ weight = self.weight_gamma * weight + self.weight_beta
+ return weight
+
+ def forward(self, x):
+ weight = self._get_weight(self.weight)
+ return F.conv2d(x, weight, self.bias, self.stride, self.padding,
+ self.dilation, self.groups)
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ """Override default load function.
+
+ AWS overrides the function _load_from_state_dict to recover
+ weight_gamma and weight_beta if they are missing. If weight_gamma and
+ weight_beta are found in the checkpoint, this function will return
+ after super()._load_from_state_dict. Otherwise, it will compute the
+ mean and std of the pretrained weights and store them in weight_beta
+ and weight_gamma.
+ """
+
+ self.weight_gamma.data.fill_(-1)
+ local_missing_keys = []
+ super()._load_from_state_dict(state_dict, prefix, local_metadata,
+ strict, local_missing_keys,
+ unexpected_keys, error_msgs)
+ if self.weight_gamma.data.mean() > 0:
+ for k in local_missing_keys:
+ missing_keys.append(k)
+ return
+ weight = self.weight.data
+ weight_flat = weight.view(weight.size(0), -1)
+ mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
+ std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
+ self.weight_beta.data.copy_(mean)
+ self.weight_gamma.data.copy_(std)
+ missing_gamma_beta = [
+ k for k in local_missing_keys
+ if k.endswith('weight_gamma') or k.endswith('weight_beta')
+ ]
+ for k in missing_gamma_beta:
+ local_missing_keys.remove(k)
+ for k in local_missing_keys:
+ missing_keys.append(k)
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/bricks/depthwise_separable_conv_module.py b/src/custom_mmpkg/custom_mmcv/cnn/bricks/depthwise_separable_conv_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..722d5d8d71f75486e2db3008907c4eadfca41d63
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/bricks/depthwise_separable_conv_module.py
@@ -0,0 +1,96 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+
+from .conv_module import ConvModule
+
+
+class DepthwiseSeparableConvModule(nn.Module):
+ """Depthwise separable convolution module.
+
+ See https://arxiv.org/pdf/1704.04861.pdf for details.
+
+ This module can replace a ConvModule with the conv block replaced by two
+ conv block: depthwise conv block and pointwise conv block. The depthwise
+ conv block contains depthwise-conv/norm/activation layers. The pointwise
+ conv block contains pointwise-conv/norm/activation layers. It should be
+ noted that there will be norm/activation layer in the depthwise conv block
+ if `norm_cfg` and `act_cfg` are specified.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ Same as that in ``nn._ConvNd``.
+ out_channels (int): Number of channels produced by the convolution.
+ Same as that in ``nn._ConvNd``.
+ kernel_size (int | tuple[int]): Size of the convolving kernel.
+ Same as that in ``nn._ConvNd``.
+ stride (int | tuple[int]): Stride of the convolution.
+ Same as that in ``nn._ConvNd``. Default: 1.
+ padding (int | tuple[int]): Zero-padding added to both sides of
+ the input. Same as that in ``nn._ConvNd``. Default: 0.
+ dilation (int | tuple[int]): Spacing between kernel elements.
+ Same as that in ``nn._ConvNd``. Default: 1.
+ norm_cfg (dict): Default norm config for both depthwise ConvModule and
+ pointwise ConvModule. Default: None.
+ act_cfg (dict): Default activation config for both depthwise ConvModule
+ and pointwise ConvModule. Default: dict(type='ReLU').
+ dw_norm_cfg (dict): Norm config of depthwise ConvModule. If it is
+ 'default', it will be the same as `norm_cfg`. Default: 'default'.
+ dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is
+ 'default', it will be the same as `act_cfg`. Default: 'default'.
+ pw_norm_cfg (dict): Norm config of pointwise ConvModule. If it is
+ 'default', it will be the same as `norm_cfg`. Default: 'default'.
+ pw_act_cfg (dict): Activation config of pointwise ConvModule. If it is
+ 'default', it will be the same as `act_cfg`. Default: 'default'.
+ kwargs (optional): Other shared arguments for depthwise and pointwise
+ ConvModule. See ConvModule for ref.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU'),
+ dw_norm_cfg='default',
+ dw_act_cfg='default',
+ pw_norm_cfg='default',
+ pw_act_cfg='default',
+ **kwargs):
+ super(DepthwiseSeparableConvModule, self).__init__()
+ assert 'groups' not in kwargs, 'groups should not be specified'
+
+ # if norm/activation config of depthwise/pointwise ConvModule is not
+ # specified, use default config.
+ dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg
+ dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg
+ pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg
+ pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg
+
+ # depthwise convolution
+ self.depthwise_conv = ConvModule(
+ in_channels,
+ in_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=in_channels,
+ norm_cfg=dw_norm_cfg,
+ act_cfg=dw_act_cfg,
+ **kwargs)
+
+ self.pointwise_conv = ConvModule(
+ in_channels,
+ out_channels,
+ 1,
+ norm_cfg=pw_norm_cfg,
+ act_cfg=pw_act_cfg,
+ **kwargs)
+
+ def forward(self, x):
+ x = self.depthwise_conv(x)
+ x = self.pointwise_conv(x)
+ return x
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/bricks/drop.py b/src/custom_mmpkg/custom_mmcv/cnn/bricks/drop.py
new file mode 100644
index 0000000000000000000000000000000000000000..90d192e3d3855d432bab5575406a09d5ff1aa94c
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/bricks/drop.py
@@ -0,0 +1,65 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+from custom_mmpkg.custom_mmcv import build_from_cfg
+from .registry import DROPOUT_LAYERS
+
+
+def drop_path(x, drop_prob=0., training=False):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
+ residual blocks).
+
+ We follow the implementation
+ https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ # handle tensors with different dimensions, not just 4D tensors.
+ shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
+ random_tensor = keep_prob + torch.rand(
+ shape, dtype=x.dtype, device=x.device)
+ output = x.div(keep_prob) * random_tensor.floor()
+ return output
+
+
+@DROPOUT_LAYERS.register_module()
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
+ residual blocks).
+
+ We follow the implementation
+ https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
+
+ Args:
+ drop_prob (float): Probability of the path to be zeroed. Default: 0.1
+ """
+
+ def __init__(self, drop_prob=0.1):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+
+@DROPOUT_LAYERS.register_module()
+class Dropout(nn.Dropout):
+ """A wrapper for ``torch.nn.Dropout``, We rename the ``p`` of
+ ``torch.nn.Dropout`` to ``drop_prob`` so as to be consistent with
+ ``DropPath``
+
+ Args:
+ drop_prob (float): Probability of the elements to be
+ zeroed. Default: 0.5.
+ inplace (bool): Do the operation inplace or not. Default: False.
+ """
+
+ def __init__(self, drop_prob=0.5, inplace=False):
+ super().__init__(p=drop_prob, inplace=inplace)
+
+
+def build_dropout(cfg, default_args=None):
+ """Builder for drop out layers."""
+ return build_from_cfg(cfg, DROPOUT_LAYERS, default_args)
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/bricks/generalized_attention.py b/src/custom_mmpkg/custom_mmcv/cnn/bricks/generalized_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..988d9adf2f289ef223bd1c680a5ae1d3387f0269
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/bricks/generalized_attention.py
@@ -0,0 +1,412 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..utils import kaiming_init
+from .registry import PLUGIN_LAYERS
+
+
+@PLUGIN_LAYERS.register_module()
+class GeneralizedAttention(nn.Module):
+ """GeneralizedAttention module.
+
+ See 'An Empirical Study of Spatial Attention Mechanisms in Deep Networks'
+ (https://arxiv.org/abs/1711.07971) for details.
+
+ Args:
+ in_channels (int): Channels of the input feature map.
+ spatial_range (int): The spatial range. -1 indicates no spatial range
+ constraint. Default: -1.
+ num_heads (int): The head number of empirical_attention module.
+ Default: 9.
+ position_embedding_dim (int): The position embedding dimension.
+ Default: -1.
+ position_magnitude (int): A multiplier acting on coord difference.
+ Default: 1.
+ kv_stride (int): The feature stride acting on key/value feature map.
+ Default: 2.
+ q_stride (int): The feature stride acting on query feature map.
+ Default: 1.
+ attention_type (str): A binary indicator string for indicating which
+ items in generalized empirical_attention module are used.
+ Default: '1111'.
+
+ - '1000' indicates 'query and key content' (appr - appr) item,
+ - '0100' indicates 'query content and relative position'
+ (appr - position) item,
+ - '0010' indicates 'key content only' (bias - appr) item,
+ - '0001' indicates 'relative position only' (bias - position) item.
+ """
+
+ _abbr_ = 'gen_attention_block'
+
+ def __init__(self,
+ in_channels,
+ spatial_range=-1,
+ num_heads=9,
+ position_embedding_dim=-1,
+ position_magnitude=1,
+ kv_stride=2,
+ q_stride=1,
+ attention_type='1111'):
+
+ super(GeneralizedAttention, self).__init__()
+
+ # hard range means local range for non-local operation
+ self.position_embedding_dim = (
+ position_embedding_dim
+ if position_embedding_dim > 0 else in_channels)
+
+ self.position_magnitude = position_magnitude
+ self.num_heads = num_heads
+ self.in_channels = in_channels
+ self.spatial_range = spatial_range
+ self.kv_stride = kv_stride
+ self.q_stride = q_stride
+ self.attention_type = [bool(int(_)) for _ in attention_type]
+ self.qk_embed_dim = in_channels // num_heads
+ out_c = self.qk_embed_dim * num_heads
+
+ if self.attention_type[0] or self.attention_type[1]:
+ self.query_conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_c,
+ kernel_size=1,
+ bias=False)
+ self.query_conv.kaiming_init = True
+
+ if self.attention_type[0] or self.attention_type[2]:
+ self.key_conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_c,
+ kernel_size=1,
+ bias=False)
+ self.key_conv.kaiming_init = True
+
+ self.v_dim = in_channels // num_heads
+ self.value_conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=self.v_dim * num_heads,
+ kernel_size=1,
+ bias=False)
+ self.value_conv.kaiming_init = True
+
+ if self.attention_type[1] or self.attention_type[3]:
+ self.appr_geom_fc_x = nn.Linear(
+ self.position_embedding_dim // 2, out_c, bias=False)
+ self.appr_geom_fc_x.kaiming_init = True
+
+ self.appr_geom_fc_y = nn.Linear(
+ self.position_embedding_dim // 2, out_c, bias=False)
+ self.appr_geom_fc_y.kaiming_init = True
+
+ if self.attention_type[2]:
+ stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
+ appr_bias_value = -2 * stdv * torch.rand(out_c) + stdv
+ self.appr_bias = nn.Parameter(appr_bias_value)
+
+ if self.attention_type[3]:
+ stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
+ geom_bias_value = -2 * stdv * torch.rand(out_c) + stdv
+ self.geom_bias = nn.Parameter(geom_bias_value)
+
+ self.proj_conv = nn.Conv2d(
+ in_channels=self.v_dim * num_heads,
+ out_channels=in_channels,
+ kernel_size=1,
+ bias=True)
+ self.proj_conv.kaiming_init = True
+ self.gamma = nn.Parameter(torch.zeros(1))
+
+ if self.spatial_range >= 0:
+ # only works when non local is after 3*3 conv
+ if in_channels == 256:
+ max_len = 84
+ elif in_channels == 512:
+ max_len = 42
+
+ max_len_kv = int((max_len - 1.0) / self.kv_stride + 1)
+ local_constraint_map = np.ones(
+ (max_len, max_len, max_len_kv, max_len_kv), dtype=np.int)
+ for iy in range(max_len):
+ for ix in range(max_len):
+ local_constraint_map[
+ iy, ix,
+ max((iy - self.spatial_range) //
+ self.kv_stride, 0):min((iy + self.spatial_range +
+ 1) // self.kv_stride +
+ 1, max_len),
+ max((ix - self.spatial_range) //
+ self.kv_stride, 0):min((ix + self.spatial_range +
+ 1) // self.kv_stride +
+ 1, max_len)] = 0
+
+ self.local_constraint_map = nn.Parameter(
+ torch.from_numpy(local_constraint_map).byte(),
+ requires_grad=False)
+
+ if self.q_stride > 1:
+ self.q_downsample = nn.AvgPool2d(
+ kernel_size=1, stride=self.q_stride)
+ else:
+ self.q_downsample = None
+
+ if self.kv_stride > 1:
+ self.kv_downsample = nn.AvgPool2d(
+ kernel_size=1, stride=self.kv_stride)
+ else:
+ self.kv_downsample = None
+
+ self.init_weights()
+
+ def get_position_embedding(self,
+ h,
+ w,
+ h_kv,
+ w_kv,
+ q_stride,
+ kv_stride,
+ device,
+ dtype,
+ feat_dim,
+ wave_length=1000):
+ # the default type of Tensor is float32, leading to type mismatch
+ # in fp16 mode. Cast it to support fp16 mode.
+ h_idxs = torch.linspace(0, h - 1, h).to(device=device, dtype=dtype)
+ h_idxs = h_idxs.view((h, 1)) * q_stride
+
+ w_idxs = torch.linspace(0, w - 1, w).to(device=device, dtype=dtype)
+ w_idxs = w_idxs.view((w, 1)) * q_stride
+
+ h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).to(
+ device=device, dtype=dtype)
+ h_kv_idxs = h_kv_idxs.view((h_kv, 1)) * kv_stride
+
+ w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).to(
+ device=device, dtype=dtype)
+ w_kv_idxs = w_kv_idxs.view((w_kv, 1)) * kv_stride
+
+ # (h, h_kv, 1)
+ h_diff = h_idxs.unsqueeze(1) - h_kv_idxs.unsqueeze(0)
+ h_diff *= self.position_magnitude
+
+ # (w, w_kv, 1)
+ w_diff = w_idxs.unsqueeze(1) - w_kv_idxs.unsqueeze(0)
+ w_diff *= self.position_magnitude
+
+ feat_range = torch.arange(0, feat_dim / 4).to(
+ device=device, dtype=dtype)
+
+ dim_mat = torch.Tensor([wave_length]).to(device=device, dtype=dtype)
+ dim_mat = dim_mat**((4. / feat_dim) * feat_range)
+ dim_mat = dim_mat.view((1, 1, -1))
+
+ embedding_x = torch.cat(
+ ((w_diff / dim_mat).sin(), (w_diff / dim_mat).cos()), dim=2)
+
+ embedding_y = torch.cat(
+ ((h_diff / dim_mat).sin(), (h_diff / dim_mat).cos()), dim=2)
+
+ return embedding_x, embedding_y
+
+ def forward(self, x_input):
+ num_heads = self.num_heads
+
+ # use empirical_attention
+ if self.q_downsample is not None:
+ x_q = self.q_downsample(x_input)
+ else:
+ x_q = x_input
+ n, _, h, w = x_q.shape
+
+ if self.kv_downsample is not None:
+ x_kv = self.kv_downsample(x_input)
+ else:
+ x_kv = x_input
+ _, _, h_kv, w_kv = x_kv.shape
+
+ if self.attention_type[0] or self.attention_type[1]:
+ proj_query = self.query_conv(x_q).view(
+ (n, num_heads, self.qk_embed_dim, h * w))
+ proj_query = proj_query.permute(0, 1, 3, 2)
+
+ if self.attention_type[0] or self.attention_type[2]:
+ proj_key = self.key_conv(x_kv).view(
+ (n, num_heads, self.qk_embed_dim, h_kv * w_kv))
+
+ if self.attention_type[1] or self.attention_type[3]:
+ position_embed_x, position_embed_y = self.get_position_embedding(
+ h, w, h_kv, w_kv, self.q_stride, self.kv_stride,
+ x_input.device, x_input.dtype, self.position_embedding_dim)
+ # (n, num_heads, w, w_kv, dim)
+ position_feat_x = self.appr_geom_fc_x(position_embed_x).\
+ view(1, w, w_kv, num_heads, self.qk_embed_dim).\
+ permute(0, 3, 1, 2, 4).\
+ repeat(n, 1, 1, 1, 1)
+
+ # (n, num_heads, h, h_kv, dim)
+ position_feat_y = self.appr_geom_fc_y(position_embed_y).\
+ view(1, h, h_kv, num_heads, self.qk_embed_dim).\
+ permute(0, 3, 1, 2, 4).\
+ repeat(n, 1, 1, 1, 1)
+
+ position_feat_x /= math.sqrt(2)
+ position_feat_y /= math.sqrt(2)
+
+ # accelerate for saliency only
+ if (np.sum(self.attention_type) == 1) and self.attention_type[2]:
+ appr_bias = self.appr_bias.\
+ view(1, num_heads, 1, self.qk_embed_dim).\
+ repeat(n, 1, 1, 1)
+
+ energy = torch.matmul(appr_bias, proj_key).\
+ view(n, num_heads, 1, h_kv * w_kv)
+
+ h = 1
+ w = 1
+ else:
+ # (n, num_heads, h*w, h_kv*w_kv), query before key, 540mb for
+ if not self.attention_type[0]:
+ energy = torch.zeros(
+ n,
+ num_heads,
+ h,
+ w,
+ h_kv,
+ w_kv,
+ dtype=x_input.dtype,
+ device=x_input.device)
+
+ # attention_type[0]: appr - appr
+ # attention_type[1]: appr - position
+ # attention_type[2]: bias - appr
+ # attention_type[3]: bias - position
+ if self.attention_type[0] or self.attention_type[2]:
+ if self.attention_type[0] and self.attention_type[2]:
+ appr_bias = self.appr_bias.\
+ view(1, num_heads, 1, self.qk_embed_dim)
+ energy = torch.matmul(proj_query + appr_bias, proj_key).\
+ view(n, num_heads, h, w, h_kv, w_kv)
+
+ elif self.attention_type[0]:
+ energy = torch.matmul(proj_query, proj_key).\
+ view(n, num_heads, h, w, h_kv, w_kv)
+
+ elif self.attention_type[2]:
+ appr_bias = self.appr_bias.\
+ view(1, num_heads, 1, self.qk_embed_dim).\
+ repeat(n, 1, 1, 1)
+
+ energy += torch.matmul(appr_bias, proj_key).\
+ view(n, num_heads, 1, 1, h_kv, w_kv)
+
+ if self.attention_type[1] or self.attention_type[3]:
+ if self.attention_type[1] and self.attention_type[3]:
+ geom_bias = self.geom_bias.\
+ view(1, num_heads, 1, self.qk_embed_dim)
+
+ proj_query_reshape = (proj_query + geom_bias).\
+ view(n, num_heads, h, w, self.qk_embed_dim)
+
+ energy_x = torch.matmul(
+ proj_query_reshape.permute(0, 1, 3, 2, 4),
+ position_feat_x.permute(0, 1, 2, 4, 3))
+ energy_x = energy_x.\
+ permute(0, 1, 3, 2, 4).unsqueeze(4)
+
+ energy_y = torch.matmul(
+ proj_query_reshape,
+ position_feat_y.permute(0, 1, 2, 4, 3))
+ energy_y = energy_y.unsqueeze(5)
+
+ energy += energy_x + energy_y
+
+ elif self.attention_type[1]:
+ proj_query_reshape = proj_query.\
+ view(n, num_heads, h, w, self.qk_embed_dim)
+ proj_query_reshape = proj_query_reshape.\
+ permute(0, 1, 3, 2, 4)
+ position_feat_x_reshape = position_feat_x.\
+ permute(0, 1, 2, 4, 3)
+ position_feat_y_reshape = position_feat_y.\
+ permute(0, 1, 2, 4, 3)
+
+ energy_x = torch.matmul(proj_query_reshape,
+ position_feat_x_reshape)
+ energy_x = energy_x.permute(0, 1, 3, 2, 4).unsqueeze(4)
+
+ energy_y = torch.matmul(proj_query_reshape,
+ position_feat_y_reshape)
+ energy_y = energy_y.unsqueeze(5)
+
+ energy += energy_x + energy_y
+
+ elif self.attention_type[3]:
+ geom_bias = self.geom_bias.\
+ view(1, num_heads, self.qk_embed_dim, 1).\
+ repeat(n, 1, 1, 1)
+
+ position_feat_x_reshape = position_feat_x.\
+ view(n, num_heads, w*w_kv, self.qk_embed_dim)
+
+ position_feat_y_reshape = position_feat_y.\
+ view(n, num_heads, h * h_kv, self.qk_embed_dim)
+
+ energy_x = torch.matmul(position_feat_x_reshape, geom_bias)
+ energy_x = energy_x.view(n, num_heads, 1, w, 1, w_kv)
+
+ energy_y = torch.matmul(position_feat_y_reshape, geom_bias)
+ energy_y = energy_y.view(n, num_heads, h, 1, h_kv, 1)
+
+ energy += energy_x + energy_y
+
+ energy = energy.view(n, num_heads, h * w, h_kv * w_kv)
+
+ if self.spatial_range >= 0:
+ cur_local_constraint_map = \
+ self.local_constraint_map[:h, :w, :h_kv, :w_kv].\
+ contiguous().\
+ view(1, 1, h*w, h_kv*w_kv)
+
+ energy = energy.masked_fill_(cur_local_constraint_map,
+ float('-inf'))
+
+ attention = F.softmax(energy, 3)
+
+ proj_value = self.value_conv(x_kv)
+ proj_value_reshape = proj_value.\
+ view((n, num_heads, self.v_dim, h_kv * w_kv)).\
+ permute(0, 1, 3, 2)
+
+ out = torch.matmul(attention, proj_value_reshape).\
+ permute(0, 1, 3, 2).\
+ contiguous().\
+ view(n, self.v_dim * self.num_heads, h, w)
+
+ out = self.proj_conv(out)
+
+ # output is downsampled, upsample back to input size
+ if self.q_downsample is not None:
+ out = F.interpolate(
+ out,
+ size=x_input.shape[2:],
+ mode='bilinear',
+ align_corners=False)
+
+ out = self.gamma * out + x_input
+ return out
+
+ def init_weights(self):
+ for m in self.modules():
+ if hasattr(m, 'kaiming_init') and m.kaiming_init:
+ kaiming_init(
+ m,
+ mode='fan_in',
+ nonlinearity='leaky_relu',
+ bias=0,
+ distribution='uniform',
+ a=1)
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/bricks/hsigmoid.py b/src/custom_mmpkg/custom_mmcv/cnn/bricks/hsigmoid.py
new file mode 100644
index 0000000000000000000000000000000000000000..30b1a3d6580cf0360710426fbea1f05acdf07b4b
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/bricks/hsigmoid.py
@@ -0,0 +1,34 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+
+from .registry import ACTIVATION_LAYERS
+
+
+@ACTIVATION_LAYERS.register_module()
+class HSigmoid(nn.Module):
+ """Hard Sigmoid Module. Apply the hard sigmoid function:
+ Hsigmoid(x) = min(max((x + bias) / divisor, min_value), max_value)
+ Default: Hsigmoid(x) = min(max((x + 1) / 2, 0), 1)
+
+ Args:
+ bias (float): Bias of the input feature map. Default: 1.0.
+ divisor (float): Divisor of the input feature map. Default: 2.0.
+ min_value (float): Lower bound value. Default: 0.0.
+ max_value (float): Upper bound value. Default: 1.0.
+
+ Returns:
+ Tensor: The output tensor.
+ """
+
+ def __init__(self, bias=1.0, divisor=2.0, min_value=0.0, max_value=1.0):
+ super(HSigmoid, self).__init__()
+ self.bias = bias
+ self.divisor = divisor
+ assert self.divisor != 0
+ self.min_value = min_value
+ self.max_value = max_value
+
+ def forward(self, x):
+ x = (x + self.bias) / self.divisor
+
+ return x.clamp_(self.min_value, self.max_value)
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/bricks/hswish.py b/src/custom_mmpkg/custom_mmcv/cnn/bricks/hswish.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e0c090ff037c99ee6c5c84c4592e87beae02208
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/bricks/hswish.py
@@ -0,0 +1,29 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+
+from .registry import ACTIVATION_LAYERS
+
+
+@ACTIVATION_LAYERS.register_module()
+class HSwish(nn.Module):
+ """Hard Swish Module.
+
+ This module applies the hard swish function:
+
+ .. math::
+ Hswish(x) = x * ReLU6(x + 3) / 6
+
+ Args:
+ inplace (bool): can optionally do the operation in-place.
+ Default: False.
+
+ Returns:
+ Tensor: The output tensor.
+ """
+
+ def __init__(self, inplace=False):
+ super(HSwish, self).__init__()
+ self.act = nn.ReLU6(inplace)
+
+ def forward(self, x):
+ return x * self.act(x + 3) / 6
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/bricks/non_local.py b/src/custom_mmpkg/custom_mmcv/cnn/bricks/non_local.py
new file mode 100644
index 0000000000000000000000000000000000000000..92d00155ef275c1201ea66bba30470a1785cc5d7
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/bricks/non_local.py
@@ -0,0 +1,306 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta
+
+import torch
+import torch.nn as nn
+
+from ..utils import constant_init, normal_init
+from .conv_module import ConvModule
+from .registry import PLUGIN_LAYERS
+
+
+class _NonLocalNd(nn.Module, metaclass=ABCMeta):
+ """Basic Non-local module.
+
+ This module is proposed in
+ "Non-local Neural Networks"
+ Paper reference: https://arxiv.org/abs/1711.07971
+ Code reference: https://github.com/AlexHex7/Non-local_pytorch
+
+ Args:
+ in_channels (int): Channels of the input feature map.
+ reduction (int): Channel reduction ratio. Default: 2.
+ use_scale (bool): Whether to scale pairwise_weight by
+ `1/sqrt(inter_channels)` when the mode is `embedded_gaussian`.
+ Default: True.
+ conv_cfg (None | dict): The config dict for convolution layers.
+ If not specified, it will use `nn.Conv2d` for convolution layers.
+ Default: None.
+ norm_cfg (None | dict): The config dict for normalization layers.
+ Default: None. (This parameter is only applicable to conv_out.)
+ mode (str): Options are `gaussian`, `concatenation`,
+ `embedded_gaussian` and `dot_product`. Default: embedded_gaussian.
+ """
+
+ def __init__(self,
+ in_channels,
+ reduction=2,
+ use_scale=True,
+ conv_cfg=None,
+ norm_cfg=None,
+ mode='embedded_gaussian',
+ **kwargs):
+ super(_NonLocalNd, self).__init__()
+ self.in_channels = in_channels
+ self.reduction = reduction
+ self.use_scale = use_scale
+ self.inter_channels = max(in_channels // reduction, 1)
+ self.mode = mode
+
+ if mode not in [
+ 'gaussian', 'embedded_gaussian', 'dot_product', 'concatenation'
+ ]:
+ raise ValueError("Mode should be in 'gaussian', 'concatenation', "
+ f"'embedded_gaussian' or 'dot_product', but got "
+ f'{mode} instead.')
+
+ # g, theta, phi are defaulted as `nn.ConvNd`.
+ # Here we use ConvModule for potential usage.
+ self.g = ConvModule(
+ self.in_channels,
+ self.inter_channels,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ act_cfg=None)
+ self.conv_out = ConvModule(
+ self.inter_channels,
+ self.in_channels,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+
+ if self.mode != 'gaussian':
+ self.theta = ConvModule(
+ self.in_channels,
+ self.inter_channels,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ act_cfg=None)
+ self.phi = ConvModule(
+ self.in_channels,
+ self.inter_channels,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ act_cfg=None)
+
+ if self.mode == 'concatenation':
+ self.concat_project = ConvModule(
+ self.inter_channels * 2,
+ 1,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ act_cfg=dict(type='ReLU'))
+
+ self.init_weights(**kwargs)
+
+ def init_weights(self, std=0.01, zeros_init=True):
+ if self.mode != 'gaussian':
+ for m in [self.g, self.theta, self.phi]:
+ normal_init(m.conv, std=std)
+ else:
+ normal_init(self.g.conv, std=std)
+ if zeros_init:
+ if self.conv_out.norm_cfg is None:
+ constant_init(self.conv_out.conv, 0)
+ else:
+ constant_init(self.conv_out.norm, 0)
+ else:
+ if self.conv_out.norm_cfg is None:
+ normal_init(self.conv_out.conv, std=std)
+ else:
+ normal_init(self.conv_out.norm, std=std)
+
+ def gaussian(self, theta_x, phi_x):
+ # NonLocal1d pairwise_weight: [N, H, H]
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
+ # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+ pairwise_weight = torch.matmul(theta_x, phi_x)
+ pairwise_weight = pairwise_weight.softmax(dim=-1)
+ return pairwise_weight
+
+ def embedded_gaussian(self, theta_x, phi_x):
+ # NonLocal1d pairwise_weight: [N, H, H]
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
+ # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+ pairwise_weight = torch.matmul(theta_x, phi_x)
+ if self.use_scale:
+ # theta_x.shape[-1] is `self.inter_channels`
+ pairwise_weight /= theta_x.shape[-1]**0.5
+ pairwise_weight = pairwise_weight.softmax(dim=-1)
+ return pairwise_weight
+
+ def dot_product(self, theta_x, phi_x):
+ # NonLocal1d pairwise_weight: [N, H, H]
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
+ # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+ pairwise_weight = torch.matmul(theta_x, phi_x)
+ pairwise_weight /= pairwise_weight.shape[-1]
+ return pairwise_weight
+
+ def concatenation(self, theta_x, phi_x):
+ # NonLocal1d pairwise_weight: [N, H, H]
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
+ # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+ h = theta_x.size(2)
+ w = phi_x.size(3)
+ theta_x = theta_x.repeat(1, 1, 1, w)
+ phi_x = phi_x.repeat(1, 1, h, 1)
+
+ concat_feature = torch.cat([theta_x, phi_x], dim=1)
+ pairwise_weight = self.concat_project(concat_feature)
+ n, _, h, w = pairwise_weight.size()
+ pairwise_weight = pairwise_weight.view(n, h, w)
+ pairwise_weight /= pairwise_weight.shape[-1]
+
+ return pairwise_weight
+
+ def forward(self, x):
+ # Assume `reduction = 1`, then `inter_channels = C`
+ # or `inter_channels = C` when `mode="gaussian"`
+
+ # NonLocal1d x: [N, C, H]
+ # NonLocal2d x: [N, C, H, W]
+ # NonLocal3d x: [N, C, T, H, W]
+ n = x.size(0)
+
+ # NonLocal1d g_x: [N, H, C]
+ # NonLocal2d g_x: [N, HxW, C]
+ # NonLocal3d g_x: [N, TxHxW, C]
+ g_x = self.g(x).view(n, self.inter_channels, -1)
+ g_x = g_x.permute(0, 2, 1)
+
+ # NonLocal1d theta_x: [N, H, C], phi_x: [N, C, H]
+ # NonLocal2d theta_x: [N, HxW, C], phi_x: [N, C, HxW]
+ # NonLocal3d theta_x: [N, TxHxW, C], phi_x: [N, C, TxHxW]
+ if self.mode == 'gaussian':
+ theta_x = x.view(n, self.in_channels, -1)
+ theta_x = theta_x.permute(0, 2, 1)
+ if self.sub_sample:
+ phi_x = self.phi(x).view(n, self.in_channels, -1)
+ else:
+ phi_x = x.view(n, self.in_channels, -1)
+ elif self.mode == 'concatenation':
+ theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
+ phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
+ else:
+ theta_x = self.theta(x).view(n, self.inter_channels, -1)
+ theta_x = theta_x.permute(0, 2, 1)
+ phi_x = self.phi(x).view(n, self.inter_channels, -1)
+
+ pairwise_func = getattr(self, self.mode)
+ # NonLocal1d pairwise_weight: [N, H, H]
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
+ # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+ pairwise_weight = pairwise_func(theta_x, phi_x)
+
+ # NonLocal1d y: [N, H, C]
+ # NonLocal2d y: [N, HxW, C]
+ # NonLocal3d y: [N, TxHxW, C]
+ y = torch.matmul(pairwise_weight, g_x)
+ # NonLocal1d y: [N, C, H]
+ # NonLocal2d y: [N, C, H, W]
+ # NonLocal3d y: [N, C, T, H, W]
+ y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
+ *x.size()[2:])
+
+ output = x + self.conv_out(y)
+
+ return output
+
+
+class NonLocal1d(_NonLocalNd):
+ """1D Non-local module.
+
+ Args:
+ in_channels (int): Same as `NonLocalND`.
+ sub_sample (bool): Whether to apply max pooling after pairwise
+ function (Note that the `sub_sample` is applied on spatial only).
+ Default: False.
+ conv_cfg (None | dict): Same as `NonLocalND`.
+ Default: dict(type='Conv1d').
+ """
+
+ def __init__(self,
+ in_channels,
+ sub_sample=False,
+ conv_cfg=dict(type='Conv1d'),
+ **kwargs):
+ super(NonLocal1d, self).__init__(
+ in_channels, conv_cfg=conv_cfg, **kwargs)
+
+ self.sub_sample = sub_sample
+
+ if sub_sample:
+ max_pool_layer = nn.MaxPool1d(kernel_size=2)
+ self.g = nn.Sequential(self.g, max_pool_layer)
+ if self.mode != 'gaussian':
+ self.phi = nn.Sequential(self.phi, max_pool_layer)
+ else:
+ self.phi = max_pool_layer
+
+
+@PLUGIN_LAYERS.register_module()
+class NonLocal2d(_NonLocalNd):
+ """2D Non-local module.
+
+ Args:
+ in_channels (int): Same as `NonLocalND`.
+ sub_sample (bool): Whether to apply max pooling after pairwise
+ function (Note that the `sub_sample` is applied on spatial only).
+ Default: False.
+ conv_cfg (None | dict): Same as `NonLocalND`.
+ Default: dict(type='Conv2d').
+ """
+
+ _abbr_ = 'nonlocal_block'
+
+ def __init__(self,
+ in_channels,
+ sub_sample=False,
+ conv_cfg=dict(type='Conv2d'),
+ **kwargs):
+ super(NonLocal2d, self).__init__(
+ in_channels, conv_cfg=conv_cfg, **kwargs)
+
+ self.sub_sample = sub_sample
+
+ if sub_sample:
+ max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
+ self.g = nn.Sequential(self.g, max_pool_layer)
+ if self.mode != 'gaussian':
+ self.phi = nn.Sequential(self.phi, max_pool_layer)
+ else:
+ self.phi = max_pool_layer
+
+
+class NonLocal3d(_NonLocalNd):
+ """3D Non-local module.
+
+ Args:
+ in_channels (int): Same as `NonLocalND`.
+ sub_sample (bool): Whether to apply max pooling after pairwise
+ function (Note that the `sub_sample` is applied on spatial only).
+ Default: False.
+ conv_cfg (None | dict): Same as `NonLocalND`.
+ Default: dict(type='Conv3d').
+ """
+
+ def __init__(self,
+ in_channels,
+ sub_sample=False,
+ conv_cfg=dict(type='Conv3d'),
+ **kwargs):
+ super(NonLocal3d, self).__init__(
+ in_channels, conv_cfg=conv_cfg, **kwargs)
+ self.sub_sample = sub_sample
+
+ if sub_sample:
+ max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
+ self.g = nn.Sequential(self.g, max_pool_layer)
+ if self.mode != 'gaussian':
+ self.phi = nn.Sequential(self.phi, max_pool_layer)
+ else:
+ self.phi = max_pool_layer
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/bricks/norm.py b/src/custom_mmpkg/custom_mmcv/cnn/bricks/norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..da7a4d5d1ec957e885c48afb2dac772b6f792fd2
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/bricks/norm.py
@@ -0,0 +1,144 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+
+import torch.nn as nn
+
+from custom_mmpkg.custom_mmcv.utils import is_tuple_of
+from custom_mmpkg.custom_mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm, _InstanceNorm
+from .registry import NORM_LAYERS
+
+NORM_LAYERS.register_module('BN', module=nn.BatchNorm2d)
+NORM_LAYERS.register_module('BN1d', module=nn.BatchNorm1d)
+NORM_LAYERS.register_module('BN2d', module=nn.BatchNorm2d)
+NORM_LAYERS.register_module('BN3d', module=nn.BatchNorm3d)
+NORM_LAYERS.register_module('SyncBN', module=SyncBatchNorm)
+NORM_LAYERS.register_module('GN', module=nn.GroupNorm)
+NORM_LAYERS.register_module('LN', module=nn.LayerNorm)
+NORM_LAYERS.register_module('IN', module=nn.InstanceNorm2d)
+NORM_LAYERS.register_module('IN1d', module=nn.InstanceNorm1d)
+NORM_LAYERS.register_module('IN2d', module=nn.InstanceNorm2d)
+NORM_LAYERS.register_module('IN3d', module=nn.InstanceNorm3d)
+
+
+def infer_abbr(class_type):
+ """Infer abbreviation from the class name.
+
+ When we build a norm layer with `build_norm_layer()`, we want to preserve
+ the norm type in variable names, e.g, self.bn1, self.gn. This method will
+ infer the abbreviation to map class types to abbreviations.
+
+ Rule 1: If the class has the property "_abbr_", return the property.
+ Rule 2: If the parent class is _BatchNorm, GroupNorm, LayerNorm or
+ InstanceNorm, the abbreviation of this layer will be "bn", "gn", "ln" and
+ "in" respectively.
+ Rule 3: If the class name contains "batch", "group", "layer" or "instance",
+ the abbreviation of this layer will be "bn", "gn", "ln" and "in"
+ respectively.
+ Rule 4: Otherwise, the abbreviation falls back to "norm".
+
+ Args:
+ class_type (type): The norm layer type.
+
+ Returns:
+ str: The inferred abbreviation.
+ """
+ if not inspect.isclass(class_type):
+ raise TypeError(
+ f'class_type must be a type, but got {type(class_type)}')
+ if hasattr(class_type, '_abbr_'):
+ return class_type._abbr_
+ if issubclass(class_type, _InstanceNorm): # IN is a subclass of BN
+ return 'in'
+ elif issubclass(class_type, _BatchNorm):
+ return 'bn'
+ elif issubclass(class_type, nn.GroupNorm):
+ return 'gn'
+ elif issubclass(class_type, nn.LayerNorm):
+ return 'ln'
+ else:
+ class_name = class_type.__name__.lower()
+ if 'batch' in class_name:
+ return 'bn'
+ elif 'group' in class_name:
+ return 'gn'
+ elif 'layer' in class_name:
+ return 'ln'
+ elif 'instance' in class_name:
+ return 'in'
+ else:
+ return 'norm_layer'
+
+
+def build_norm_layer(cfg, num_features, postfix=''):
+ """Build normalization layer.
+
+ Args:
+ cfg (dict): The norm layer config, which should contain:
+
+ - type (str): Layer type.
+ - layer args: Args needed to instantiate a norm layer.
+ - requires_grad (bool, optional): Whether stop gradient updates.
+ num_features (int): Number of input channels.
+ postfix (int | str): The postfix to be appended into norm abbreviation
+ to create named layer.
+
+ Returns:
+ (str, nn.Module): The first element is the layer name consisting of
+ abbreviation and postfix, e.g., bn1, gn. The second element is the
+ created norm layer.
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError('cfg must be a dict')
+ if 'type' not in cfg:
+ raise KeyError('the cfg dict must contain the key "type"')
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop('type')
+ if layer_type not in NORM_LAYERS:
+ raise KeyError(f'Unrecognized norm type {layer_type}')
+
+ norm_layer = NORM_LAYERS.get(layer_type)
+ abbr = infer_abbr(norm_layer)
+
+ assert isinstance(postfix, (int, str))
+ name = abbr + str(postfix)
+
+ requires_grad = cfg_.pop('requires_grad', True)
+ cfg_.setdefault('eps', 1e-5)
+ if layer_type != 'GN':
+ layer = norm_layer(num_features, **cfg_)
+ if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
+ layer._specify_ddp_gpu_num(1)
+ else:
+ assert 'num_groups' in cfg_
+ layer = norm_layer(num_channels=num_features, **cfg_)
+
+ for param in layer.parameters():
+ param.requires_grad = requires_grad
+
+ return name, layer
+
+
+def is_norm(layer, exclude=None):
+ """Check if a layer is a normalization layer.
+
+ Args:
+ layer (nn.Module): The layer to be checked.
+ exclude (type | tuple[type]): Types to be excluded.
+
+ Returns:
+ bool: Whether the layer is a norm layer.
+ """
+ if exclude is not None:
+ if not isinstance(exclude, tuple):
+ exclude = (exclude, )
+ if not is_tuple_of(exclude, type):
+ raise TypeError(
+ f'"exclude" must be either None or type or a tuple of types, '
+ f'but got {type(exclude)}: {exclude}')
+
+ if exclude and isinstance(layer, exclude):
+ return False
+
+ all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
+ return isinstance(layer, all_norm_bases)
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/bricks/padding.py b/src/custom_mmpkg/custom_mmcv/cnn/bricks/padding.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4ac6b28a1789bd551c613a7d3e7b622433ac7ec
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/bricks/padding.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+
+from .registry import PADDING_LAYERS
+
+PADDING_LAYERS.register_module('zero', module=nn.ZeroPad2d)
+PADDING_LAYERS.register_module('reflect', module=nn.ReflectionPad2d)
+PADDING_LAYERS.register_module('replicate', module=nn.ReplicationPad2d)
+
+
+def build_padding_layer(cfg, *args, **kwargs):
+ """Build padding layer.
+
+ Args:
+ cfg (None or dict): The padding layer config, which should contain:
+ - type (str): Layer type.
+ - layer args: Args needed to instantiate a padding layer.
+
+ Returns:
+ nn.Module: Created padding layer.
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError('cfg must be a dict')
+ if 'type' not in cfg:
+ raise KeyError('the cfg dict must contain the key "type"')
+
+ cfg_ = cfg.copy()
+ padding_type = cfg_.pop('type')
+ if padding_type not in PADDING_LAYERS:
+ raise KeyError(f'Unrecognized padding type {padding_type}.')
+ else:
+ padding_layer = PADDING_LAYERS.get(padding_type)
+
+ layer = padding_layer(*args, **kwargs, **cfg_)
+
+ return layer
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/bricks/plugin.py b/src/custom_mmpkg/custom_mmcv/cnn/bricks/plugin.py
new file mode 100644
index 0000000000000000000000000000000000000000..07c010d4053174dd41107aa654ea67e82b46a25c
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/bricks/plugin.py
@@ -0,0 +1,88 @@
+import inspect
+import platform
+
+from .registry import PLUGIN_LAYERS
+
+if platform.system() == 'Windows':
+ import regex as re
+else:
+ import re
+
+
+def infer_abbr(class_type):
+ """Infer abbreviation from the class name.
+
+ This method will infer the abbreviation to map class types to
+ abbreviations.
+
+ Rule 1: If the class has the property "abbr", return the property.
+ Rule 2: Otherwise, the abbreviation falls back to snake case of class
+ name, e.g. the abbreviation of ``FancyBlock`` will be ``fancy_block``.
+
+ Args:
+ class_type (type): The norm layer type.
+
+ Returns:
+ str: The inferred abbreviation.
+ """
+
+ def camel2snack(word):
+ """Convert camel case word into snack case.
+
+ Modified from `inflection lib
+ `_.
+
+ Example::
+
+ >>> camel2snack("FancyBlock")
+ 'fancy_block'
+ """
+
+ word = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', word)
+ word = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', word)
+ word = word.replace('-', '_')
+ return word.lower()
+
+ if not inspect.isclass(class_type):
+ raise TypeError(
+ f'class_type must be a type, but got {type(class_type)}')
+ if hasattr(class_type, '_abbr_'):
+ return class_type._abbr_
+ else:
+ return camel2snack(class_type.__name__)
+
+
+def build_plugin_layer(cfg, postfix='', **kwargs):
+ """Build plugin layer.
+
+ Args:
+ cfg (None or dict): cfg should contain:
+ type (str): identify plugin layer type.
+ layer args: args needed to instantiate a plugin layer.
+ postfix (int, str): appended into norm abbreviation to
+ create named layer. Default: ''.
+
+ Returns:
+ tuple[str, nn.Module]:
+ name (str): abbreviation + postfix
+ layer (nn.Module): created plugin layer
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError('cfg must be a dict')
+ if 'type' not in cfg:
+ raise KeyError('the cfg dict must contain the key "type"')
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop('type')
+ if layer_type not in PLUGIN_LAYERS:
+ raise KeyError(f'Unrecognized plugin type {layer_type}')
+
+ plugin_layer = PLUGIN_LAYERS.get(layer_type)
+ abbr = infer_abbr(plugin_layer)
+
+ assert isinstance(postfix, (int, str))
+ name = abbr + str(postfix)
+
+ layer = plugin_layer(**kwargs, **cfg_)
+
+ return name, layer
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/bricks/registry.py b/src/custom_mmpkg/custom_mmcv/cnn/bricks/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..496c18796f08a9de159b489fbef278ded22749d8
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/bricks/registry.py
@@ -0,0 +1,16 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from custom_mmpkg.custom_mmcv.utils import Registry
+
+CONV_LAYERS = Registry('conv layer')
+NORM_LAYERS = Registry('norm layer')
+ACTIVATION_LAYERS = Registry('activation layer')
+PADDING_LAYERS = Registry('padding layer')
+UPSAMPLE_LAYERS = Registry('upsample layer')
+PLUGIN_LAYERS = Registry('plugin layer')
+
+DROPOUT_LAYERS = Registry('drop out layers')
+POSITIONAL_ENCODING = Registry('position encoding')
+ATTENTION = Registry('attention')
+FEEDFORWARD_NETWORK = Registry('feed-forward Network')
+TRANSFORMER_LAYER = Registry('transformerLayer')
+TRANSFORMER_LAYER_SEQUENCE = Registry('transformer-layers sequence')
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/bricks/scale.py b/src/custom_mmpkg/custom_mmcv/cnn/bricks/scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..c905fffcc8bf998d18d94f927591963c428025e2
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/bricks/scale.py
@@ -0,0 +1,21 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+
+class Scale(nn.Module):
+ """A learnable scale parameter.
+
+ This layer scales the input by a learnable factor. It multiplies a
+ learnable scale parameter of shape (1,) with input of any shape.
+
+ Args:
+ scale (float): Initial value of scale factor. Default: 1.0
+ """
+
+ def __init__(self, scale=1.0):
+ super(Scale, self).__init__()
+ self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))
+
+ def forward(self, x):
+ return x * self.scale
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/bricks/swish.py b/src/custom_mmpkg/custom_mmcv/cnn/bricks/swish.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2ca8ed7b749413f011ae54aac0cab27e6f0b51f
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/bricks/swish.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+from .registry import ACTIVATION_LAYERS
+
+
+@ACTIVATION_LAYERS.register_module()
+class Swish(nn.Module):
+ """Swish Module.
+
+ This module applies the swish function:
+
+ .. math::
+ Swish(x) = x * Sigmoid(x)
+
+ Returns:
+ Tensor: The output tensor.
+ """
+
+ def __init__(self):
+ super(Swish, self).__init__()
+
+ def forward(self, x):
+ return x * torch.sigmoid(x)
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/bricks/transformer.py b/src/custom_mmpkg/custom_mmcv/cnn/bricks/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4cd4655d30aef5cecb65522bc6b854fb60eca8d
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/bricks/transformer.py
@@ -0,0 +1,595 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import warnings
+
+import torch
+import torch.nn as nn
+
+from custom_mmpkg.custom_mmcv import ConfigDict, deprecated_api_warning
+from custom_mmpkg.custom_mmcv.cnn import Linear, build_activation_layer, build_norm_layer
+from custom_mmpkg.custom_mmcv.runner.base_module import BaseModule, ModuleList, Sequential
+from custom_mmpkg.custom_mmcv.utils import build_from_cfg
+from .drop import build_dropout
+from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
+ TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)
+
+# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
+try:
+ from custom_mmpkg.custom_mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention # noqa F401
+ warnings.warn(
+ ImportWarning(
+ '``MultiScaleDeformableAttention`` has been moved to '
+ '``mmcv.ops.multi_scale_deform_attn``, please change original path ' # noqa E501
+ '``from custom_mmpkg.custom_mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` ' # noqa E501
+ 'to ``from custom_mmpkg.custom_mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` ' # noqa E501
+ ))
+
+except ImportError:
+ warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '
+ '``mmcv.ops.multi_scale_deform_attn``, '
+ 'You should install ``mmcv-full`` if you need this module. ')
+
+
+def build_positional_encoding(cfg, default_args=None):
+ """Builder for Position Encoding."""
+ return build_from_cfg(cfg, POSITIONAL_ENCODING, default_args)
+
+
+def build_attention(cfg, default_args=None):
+ """Builder for attention."""
+ return build_from_cfg(cfg, ATTENTION, default_args)
+
+
+def build_feedforward_network(cfg, default_args=None):
+ """Builder for feed-forward network (FFN)."""
+ return build_from_cfg(cfg, FEEDFORWARD_NETWORK, default_args)
+
+
+def build_transformer_layer(cfg, default_args=None):
+ """Builder for transformer layer."""
+ return build_from_cfg(cfg, TRANSFORMER_LAYER, default_args)
+
+
+def build_transformer_layer_sequence(cfg, default_args=None):
+ """Builder for transformer encoder and transformer decoder."""
+ return build_from_cfg(cfg, TRANSFORMER_LAYER_SEQUENCE, default_args)
+
+
+@ATTENTION.register_module()
+class MultiheadAttention(BaseModule):
+ """A wrapper for ``torch.nn.MultiheadAttention``.
+
+ This module implements MultiheadAttention with identity connection,
+ and positional encoding is also passed as input.
+
+ Args:
+ embed_dims (int): The embedding dimension.
+ num_heads (int): Parallel attention heads.
+ attn_drop (float): A Dropout layer on attn_output_weights.
+ Default: 0.0.
+ proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
+ Default: 0.0.
+ dropout_layer (obj:`ConfigDict`): The dropout_layer used
+ when adding the shortcut.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ batch_first (bool): When it is True, Key, Query and Value are shape of
+ (batch, n, embed_dim), otherwise (n, batch, embed_dim).
+ Default to False.
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads,
+ attn_drop=0.,
+ proj_drop=0.,
+ dropout_layer=dict(type='Dropout', drop_prob=0.),
+ init_cfg=None,
+ batch_first=False,
+ **kwargs):
+ super(MultiheadAttention, self).__init__(init_cfg)
+ if 'dropout' in kwargs:
+ warnings.warn('The arguments `dropout` in MultiheadAttention '
+ 'has been deprecated, now you can separately '
+ 'set `attn_drop`(float), proj_drop(float), '
+ 'and `dropout_layer`(dict) ')
+ attn_drop = kwargs['dropout']
+ dropout_layer['drop_prob'] = kwargs.pop('dropout')
+
+ self.embed_dims = embed_dims
+ self.num_heads = num_heads
+ self.batch_first = batch_first
+
+ self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
+ **kwargs)
+
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.dropout_layer = build_dropout(
+ dropout_layer) if dropout_layer else nn.Identity()
+
+ @deprecated_api_warning({'residual': 'identity'},
+ cls_name='MultiheadAttention')
+ def forward(self,
+ query,
+ key=None,
+ value=None,
+ identity=None,
+ query_pos=None,
+ key_pos=None,
+ attn_mask=None,
+ key_padding_mask=None,
+ **kwargs):
+ """Forward function for `MultiheadAttention`.
+
+ **kwargs allow passing a more general data flow when combining
+ with other operations in `transformerlayer`.
+
+ Args:
+ query (Tensor): The input query with shape [num_queries, bs,
+ embed_dims] if self.batch_first is False, else
+ [bs, num_queries embed_dims].
+ key (Tensor): The key tensor with shape [num_keys, bs,
+ embed_dims] if self.batch_first is False, else
+ [bs, num_keys, embed_dims] .
+ If None, the ``query`` will be used. Defaults to None.
+ value (Tensor): The value tensor with same shape as `key`.
+ Same in `nn.MultiheadAttention.forward`. Defaults to None.
+ If None, the `key` will be used.
+ identity (Tensor): This tensor, with the same shape as x,
+ will be used for the identity link.
+ If None, `x` will be used. Defaults to None.
+ query_pos (Tensor): The positional encoding for query, with
+ the same shape as `x`. If not None, it will
+ be added to `x` before forward function. Defaults to None.
+ key_pos (Tensor): The positional encoding for `key`, with the
+ same shape as `key`. Defaults to None. If not None, it will
+ be added to `key` before forward function. If None, and
+ `query_pos` has the same shape as `key`, then `query_pos`
+ will be used for `key_pos`. Defaults to None.
+ attn_mask (Tensor): ByteTensor mask with shape [num_queries,
+ num_keys]. Same in `nn.MultiheadAttention.forward`.
+ Defaults to None.
+ key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
+ Defaults to None.
+
+ Returns:
+ Tensor: forwarded results with shape
+ [num_queries, bs, embed_dims]
+ if self.batch_first is False, else
+ [bs, num_queries embed_dims].
+ """
+
+ if key is None:
+ key = query
+ if value is None:
+ value = key
+ if identity is None:
+ identity = query
+ if key_pos is None:
+ if query_pos is not None:
+ # use query_pos if key_pos is not available
+ if query_pos.shape == key.shape:
+ key_pos = query_pos
+ else:
+ warnings.warn(f'position encoding of key is'
+ f'missing in {self.__class__.__name__}.')
+ if query_pos is not None:
+ query = query + query_pos
+ if key_pos is not None:
+ key = key + key_pos
+
+ # Because the dataflow('key', 'query', 'value') of
+ # ``torch.nn.MultiheadAttention`` is (num_query, batch,
+ # embed_dims), We should adjust the shape of dataflow from
+ # batch_first (batch, num_query, embed_dims) to num_query_first
+ # (num_query ,batch, embed_dims), and recover ``attn_output``
+ # from num_query_first to batch_first.
+ if self.batch_first:
+ query = query.transpose(0, 1)
+ key = key.transpose(0, 1)
+ value = value.transpose(0, 1)
+
+ out = self.attn(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ key_padding_mask=key_padding_mask)[0]
+
+ if self.batch_first:
+ out = out.transpose(0, 1)
+
+ return identity + self.dropout_layer(self.proj_drop(out))
+
+
+@FEEDFORWARD_NETWORK.register_module()
+class FFN(BaseModule):
+ """Implements feed-forward networks (FFNs) with identity connection.
+
+ Args:
+ embed_dims (int): The feature dimension. Same as
+ `MultiheadAttention`. Defaults: 256.
+ feedforward_channels (int): The hidden dimension of FFNs.
+ Defaults: 1024.
+ num_fcs (int, optional): The number of fully-connected layers in
+ FFNs. Default: 2.
+ act_cfg (dict, optional): The activation config for FFNs.
+ Default: dict(type='ReLU')
+ ffn_drop (float, optional): Probability of an element to be
+ zeroed in FFN. Default 0.0.
+ add_identity (bool, optional): Whether to add the
+ identity connection. Default: `True`.
+ dropout_layer (obj:`ConfigDict`): The dropout_layer used
+ when adding the shortcut.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ """
+
+ @deprecated_api_warning(
+ {
+ 'dropout': 'ffn_drop',
+ 'add_residual': 'add_identity'
+ },
+ cls_name='FFN')
+ def __init__(self,
+ embed_dims=256,
+ feedforward_channels=1024,
+ num_fcs=2,
+ act_cfg=dict(type='ReLU', inplace=True),
+ ffn_drop=0.,
+ dropout_layer=None,
+ add_identity=True,
+ init_cfg=None,
+ **kwargs):
+ super(FFN, self).__init__(init_cfg)
+ assert num_fcs >= 2, 'num_fcs should be no less ' \
+ f'than 2. got {num_fcs}.'
+ self.embed_dims = embed_dims
+ self.feedforward_channels = feedforward_channels
+ self.num_fcs = num_fcs
+ self.act_cfg = act_cfg
+ self.activate = build_activation_layer(act_cfg)
+
+ layers = []
+ in_channels = embed_dims
+ for _ in range(num_fcs - 1):
+ layers.append(
+ Sequential(
+ Linear(in_channels, feedforward_channels), self.activate,
+ nn.Dropout(ffn_drop)))
+ in_channels = feedforward_channels
+ layers.append(Linear(feedforward_channels, embed_dims))
+ layers.append(nn.Dropout(ffn_drop))
+ self.layers = Sequential(*layers)
+ self.dropout_layer = build_dropout(
+ dropout_layer) if dropout_layer else torch.nn.Identity()
+ self.add_identity = add_identity
+
+ @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
+ def forward(self, x, identity=None):
+ """Forward function for `FFN`.
+
+ The function would add x to the output tensor if residue is None.
+ """
+ out = self.layers(x)
+ if not self.add_identity:
+ return self.dropout_layer(out)
+ if identity is None:
+ identity = x
+ return identity + self.dropout_layer(out)
+
+
+@TRANSFORMER_LAYER.register_module()
+class BaseTransformerLayer(BaseModule):
+ """Base `TransformerLayer` for vision transformer.
+
+ It can be built from `mmcv.ConfigDict` and support more flexible
+ customization, for example, using any number of `FFN or LN ` and
+ use different kinds of `attention` by specifying a list of `ConfigDict`
+ named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
+ when you specifying `norm` as the first element of `operation_order`.
+ More details about the `prenorm`: `On Layer Normalization in the
+ Transformer Architecture `_ .
+
+ Args:
+ attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
+ Configs for `self_attention` or `cross_attention` modules,
+ The order of the configs in the list should be consistent with
+ corresponding attentions in operation_order.
+ If it is a dict, all of the attention modules in operation_order
+ will be built with this config. Default: None.
+ ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
+ Configs for FFN, The order of the configs in the list should be
+ consistent with corresponding ffn in operation_order.
+ If it is a dict, all of the attention modules in operation_order
+ will be built with this config.
+ operation_order (tuple[str]): The execution order of operation
+ in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
+ Support `prenorm` when you specifying first element as `norm`.
+ Default:None.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='LN').
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ batch_first (bool): Key, Query and Value are shape
+ of (batch, n, embed_dim)
+ or (n, batch, embed_dim). Default to False.
+ """
+
+ def __init__(self,
+ attn_cfgs=None,
+ ffn_cfgs=dict(
+ type='FFN',
+ embed_dims=256,
+ feedforward_channels=1024,
+ num_fcs=2,
+ ffn_drop=0.,
+ act_cfg=dict(type='ReLU', inplace=True),
+ ),
+ operation_order=None,
+ norm_cfg=dict(type='LN'),
+ init_cfg=None,
+ batch_first=False,
+ **kwargs):
+
+ deprecated_args = dict(
+ feedforward_channels='feedforward_channels',
+ ffn_dropout='ffn_drop',
+ ffn_num_fcs='num_fcs')
+ for ori_name, new_name in deprecated_args.items():
+ if ori_name in kwargs:
+ warnings.warn(
+ f'The arguments `{ori_name}` in BaseTransformerLayer '
+ f'has been deprecated, now you should set `{new_name}` '
+ f'and other FFN related arguments '
+ f'to a dict named `ffn_cfgs`. ')
+ ffn_cfgs[new_name] = kwargs[ori_name]
+
+ super(BaseTransformerLayer, self).__init__(init_cfg)
+
+ self.batch_first = batch_first
+
+ assert set(operation_order) & set(
+ ['self_attn', 'norm', 'ffn', 'cross_attn']) == \
+ set(operation_order), f'The operation_order of' \
+ f' {self.__class__.__name__} should ' \
+ f'contains all four operation type ' \
+ f"{['self_attn', 'norm', 'ffn', 'cross_attn']}"
+
+ num_attn = operation_order.count('self_attn') + operation_order.count(
+ 'cross_attn')
+ if isinstance(attn_cfgs, dict):
+ attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
+ else:
+ assert num_attn == len(attn_cfgs), f'The length ' \
+ f'of attn_cfg {num_attn} is ' \
+ f'not consistent with the number of attention' \
+ f'in operation_order {operation_order}.'
+
+ self.num_attn = num_attn
+ self.operation_order = operation_order
+ self.norm_cfg = norm_cfg
+ self.pre_norm = operation_order[0] == 'norm'
+ self.attentions = ModuleList()
+
+ index = 0
+ for operation_name in operation_order:
+ if operation_name in ['self_attn', 'cross_attn']:
+ if 'batch_first' in attn_cfgs[index]:
+ assert self.batch_first == attn_cfgs[index]['batch_first']
+ else:
+ attn_cfgs[index]['batch_first'] = self.batch_first
+ attention = build_attention(attn_cfgs[index])
+ # Some custom attentions used as `self_attn`
+ # or `cross_attn` can have different behavior.
+ attention.operation_name = operation_name
+ self.attentions.append(attention)
+ index += 1
+
+ self.embed_dims = self.attentions[0].embed_dims
+
+ self.ffns = ModuleList()
+ num_ffns = operation_order.count('ffn')
+ if isinstance(ffn_cfgs, dict):
+ ffn_cfgs = ConfigDict(ffn_cfgs)
+ if isinstance(ffn_cfgs, dict):
+ ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
+ assert len(ffn_cfgs) == num_ffns
+ for ffn_index in range(num_ffns):
+ if 'embed_dims' not in ffn_cfgs[ffn_index]:
+ ffn_cfgs['embed_dims'] = self.embed_dims
+ else:
+ assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
+ self.ffns.append(
+ build_feedforward_network(ffn_cfgs[ffn_index],
+ dict(type='FFN')))
+
+ self.norms = ModuleList()
+ num_norms = operation_order.count('norm')
+ for _ in range(num_norms):
+ self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])
+
+ def forward(self,
+ query,
+ key=None,
+ value=None,
+ query_pos=None,
+ key_pos=None,
+ attn_masks=None,
+ query_key_padding_mask=None,
+ key_padding_mask=None,
+ **kwargs):
+ """Forward function for `TransformerDecoderLayer`.
+
+ **kwargs contains some specific arguments of attentions.
+
+ Args:
+ query (Tensor): The input query with shape
+ [num_queries, bs, embed_dims] if
+ self.batch_first is False, else
+ [bs, num_queries embed_dims].
+ key (Tensor): The key tensor with shape [num_keys, bs,
+ embed_dims] if self.batch_first is False, else
+ [bs, num_keys, embed_dims] .
+ value (Tensor): The value tensor with same shape as `key`.
+ query_pos (Tensor): The positional encoding for `query`.
+ Default: None.
+ key_pos (Tensor): The positional encoding for `key`.
+ Default: None.
+ attn_masks (List[Tensor] | None): 2D Tensor used in
+ calculation of corresponding attention. The length of
+ it should equal to the number of `attention` in
+ `operation_order`. Default: None.
+ query_key_padding_mask (Tensor): ByteTensor for `query`, with
+ shape [bs, num_queries]. Only used in `self_attn` layer.
+ Defaults to None.
+ key_padding_mask (Tensor): ByteTensor for `query`, with
+ shape [bs, num_keys]. Default: None.
+
+ Returns:
+ Tensor: forwarded results with shape [num_queries, bs, embed_dims].
+ """
+
+ norm_index = 0
+ attn_index = 0
+ ffn_index = 0
+ identity = query
+ if attn_masks is None:
+ attn_masks = [None for _ in range(self.num_attn)]
+ elif isinstance(attn_masks, torch.Tensor):
+ attn_masks = [
+ copy.deepcopy(attn_masks) for _ in range(self.num_attn)
+ ]
+ warnings.warn(f'Use same attn_mask in all attentions in '
+ f'{self.__class__.__name__} ')
+ else:
+ assert len(attn_masks) == self.num_attn, f'The length of ' \
+ f'attn_masks {len(attn_masks)} must be equal ' \
+ f'to the number of attention in ' \
+ f'operation_order {self.num_attn}'
+
+ for layer in self.operation_order:
+ if layer == 'self_attn':
+ temp_key = temp_value = query
+ query = self.attentions[attn_index](
+ query,
+ temp_key,
+ temp_value,
+ identity if self.pre_norm else None,
+ query_pos=query_pos,
+ key_pos=query_pos,
+ attn_mask=attn_masks[attn_index],
+ key_padding_mask=query_key_padding_mask,
+ **kwargs)
+ attn_index += 1
+ identity = query
+
+ elif layer == 'norm':
+ query = self.norms[norm_index](query)
+ norm_index += 1
+
+ elif layer == 'cross_attn':
+ query = self.attentions[attn_index](
+ query,
+ key,
+ value,
+ identity if self.pre_norm else None,
+ query_pos=query_pos,
+ key_pos=key_pos,
+ attn_mask=attn_masks[attn_index],
+ key_padding_mask=key_padding_mask,
+ **kwargs)
+ attn_index += 1
+ identity = query
+
+ elif layer == 'ffn':
+ query = self.ffns[ffn_index](
+ query, identity if self.pre_norm else None)
+ ffn_index += 1
+
+ return query
+
+
+@TRANSFORMER_LAYER_SEQUENCE.register_module()
+class TransformerLayerSequence(BaseModule):
+ """Base class for TransformerEncoder and TransformerDecoder in vision
+ transformer.
+
+ As base-class of Encoder and Decoder in vision transformer.
+ Support customization such as specifying different kind
+ of `transformer_layer` in `transformer_coder`.
+
+ Args:
+ transformerlayer (list[obj:`mmcv.ConfigDict`] |
+ obj:`mmcv.ConfigDict`): Config of transformerlayer
+ in TransformerCoder. If it is obj:`mmcv.ConfigDict`,
+ it would be repeated `num_layer` times to a
+ list[`mmcv.ConfigDict`]. Default: None.
+ num_layers (int): The number of `TransformerLayer`. Default: None.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
+ super(TransformerLayerSequence, self).__init__(init_cfg)
+ if isinstance(transformerlayers, dict):
+ transformerlayers = [
+ copy.deepcopy(transformerlayers) for _ in range(num_layers)
+ ]
+ else:
+ assert isinstance(transformerlayers, list) and \
+ len(transformerlayers) == num_layers
+ self.num_layers = num_layers
+ self.layers = ModuleList()
+ for i in range(num_layers):
+ self.layers.append(build_transformer_layer(transformerlayers[i]))
+ self.embed_dims = self.layers[0].embed_dims
+ self.pre_norm = self.layers[0].pre_norm
+
+ def forward(self,
+ query,
+ key,
+ value,
+ query_pos=None,
+ key_pos=None,
+ attn_masks=None,
+ query_key_padding_mask=None,
+ key_padding_mask=None,
+ **kwargs):
+ """Forward function for `TransformerCoder`.
+
+ Args:
+ query (Tensor): Input query with shape
+ `(num_queries, bs, embed_dims)`.
+ key (Tensor): The key tensor with shape
+ `(num_keys, bs, embed_dims)`.
+ value (Tensor): The value tensor with shape
+ `(num_keys, bs, embed_dims)`.
+ query_pos (Tensor): The positional encoding for `query`.
+ Default: None.
+ key_pos (Tensor): The positional encoding for `key`.
+ Default: None.
+ attn_masks (List[Tensor], optional): Each element is 2D Tensor
+ which is used in calculation of corresponding attention in
+ operation_order. Default: None.
+ query_key_padding_mask (Tensor): ByteTensor for `query`, with
+ shape [bs, num_queries]. Only used in self-attention
+ Default: None.
+ key_padding_mask (Tensor): ByteTensor for `query`, with
+ shape [bs, num_keys]. Default: None.
+
+ Returns:
+ Tensor: results with shape [num_queries, bs, embed_dims].
+ """
+ for layer in self.layers:
+ query = layer(
+ query,
+ key,
+ value,
+ query_pos=query_pos,
+ key_pos=key_pos,
+ attn_masks=attn_masks,
+ query_key_padding_mask=query_key_padding_mask,
+ key_padding_mask=key_padding_mask,
+ **kwargs)
+ return query
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/bricks/upsample.py b/src/custom_mmpkg/custom_mmcv/cnn/bricks/upsample.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1a353767d0ce8518f0d7289bed10dba0178ed12
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/bricks/upsample.py
@@ -0,0 +1,84 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..utils import xavier_init
+from .registry import UPSAMPLE_LAYERS
+
+UPSAMPLE_LAYERS.register_module('nearest', module=nn.Upsample)
+UPSAMPLE_LAYERS.register_module('bilinear', module=nn.Upsample)
+
+
+@UPSAMPLE_LAYERS.register_module(name='pixel_shuffle')
+class PixelShufflePack(nn.Module):
+ """Pixel Shuffle upsample layer.
+
+ This module packs `F.pixel_shuffle()` and a nn.Conv2d module together to
+ achieve a simple upsampling with pixel shuffle.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ scale_factor (int): Upsample ratio.
+ upsample_kernel (int): Kernel size of the conv layer to expand the
+ channels.
+ """
+
+ def __init__(self, in_channels, out_channels, scale_factor,
+ upsample_kernel):
+ super(PixelShufflePack, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.scale_factor = scale_factor
+ self.upsample_kernel = upsample_kernel
+ self.upsample_conv = nn.Conv2d(
+ self.in_channels,
+ self.out_channels * scale_factor * scale_factor,
+ self.upsample_kernel,
+ padding=(self.upsample_kernel - 1) // 2)
+ self.init_weights()
+
+ def init_weights(self):
+ xavier_init(self.upsample_conv, distribution='uniform')
+
+ def forward(self, x):
+ x = self.upsample_conv(x)
+ x = F.pixel_shuffle(x, self.scale_factor)
+ return x
+
+
+def build_upsample_layer(cfg, *args, **kwargs):
+ """Build upsample layer.
+
+ Args:
+ cfg (dict): The upsample layer config, which should contain:
+
+ - type (str): Layer type.
+ - scale_factor (int): Upsample ratio, which is not applicable to
+ deconv.
+ - layer args: Args needed to instantiate a upsample layer.
+ args (argument list): Arguments passed to the ``__init__``
+ method of the corresponding conv layer.
+ kwargs (keyword arguments): Keyword arguments passed to the
+ ``__init__`` method of the corresponding conv layer.
+
+ Returns:
+ nn.Module: Created upsample layer.
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
+ if 'type' not in cfg:
+ raise KeyError(
+ f'the cfg dict must contain the key "type", but got {cfg}')
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop('type')
+ if layer_type not in UPSAMPLE_LAYERS:
+ raise KeyError(f'Unrecognized upsample type {layer_type}')
+ else:
+ upsample = UPSAMPLE_LAYERS.get(layer_type)
+
+ if upsample is nn.Upsample:
+ cfg_['mode'] = layer_type
+ layer = upsample(*args, **kwargs, **cfg_)
+ return layer
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/bricks/wrappers.py b/src/custom_mmpkg/custom_mmcv/cnn/bricks/wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..8aebf67bf52355a513f21756ee74fe510902d075
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/bricks/wrappers.py
@@ -0,0 +1,180 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+r"""Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/wrappers.py # noqa: E501
+
+Wrap some nn modules to support empty tensor input. Currently, these wrappers
+are mainly used in mask heads like fcn_mask_head and maskiou_heads since mask
+heads are trained on only positive RoIs.
+"""
+import math
+
+import torch
+import torch.nn as nn
+from torch.nn.modules.utils import _pair, _triple
+
+from .registry import CONV_LAYERS, UPSAMPLE_LAYERS
+
+if torch.__version__ == 'parrots':
+ TORCH_VERSION = torch.__version__
+else:
+ # torch.__version__ could be 1.3.1+cu92, we only need the first two
+ # for comparison
+ TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
+
+
+def obsolete_torch_version(torch_version, version_threshold):
+ return torch_version == 'parrots' or torch_version <= version_threshold
+
+
+class NewEmptyTensorOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, new_shape):
+ ctx.shape = x.shape
+ return x.new_empty(new_shape)
+
+ @staticmethod
+ def backward(ctx, grad):
+ shape = ctx.shape
+ return NewEmptyTensorOp.apply(grad, shape), None
+
+
+@CONV_LAYERS.register_module('Conv', force=True)
+class Conv2d(nn.Conv2d):
+
+ def forward(self, x):
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+ out_shape = [x.shape[0], self.out_channels]
+ for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
+ self.padding, self.stride, self.dilation):
+ o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
+ out_shape.append(o)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ if self.training:
+ # produce dummy gradient to avoid DDP warning.
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+ return empty + dummy
+ else:
+ return empty
+
+ return super().forward(x)
+
+
+@CONV_LAYERS.register_module('Conv3d', force=True)
+class Conv3d(nn.Conv3d):
+
+ def forward(self, x):
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+ out_shape = [x.shape[0], self.out_channels]
+ for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
+ self.padding, self.stride, self.dilation):
+ o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
+ out_shape.append(o)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ if self.training:
+ # produce dummy gradient to avoid DDP warning.
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+ return empty + dummy
+ else:
+ return empty
+
+ return super().forward(x)
+
+
+@CONV_LAYERS.register_module()
+@CONV_LAYERS.register_module('deconv')
+@UPSAMPLE_LAYERS.register_module('deconv', force=True)
+class ConvTranspose2d(nn.ConvTranspose2d):
+
+ def forward(self, x):
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+ out_shape = [x.shape[0], self.out_channels]
+ for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
+ self.padding, self.stride,
+ self.dilation, self.output_padding):
+ out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ if self.training:
+ # produce dummy gradient to avoid DDP warning.
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+ return empty + dummy
+ else:
+ return empty
+
+ return super().forward(x)
+
+
+@CONV_LAYERS.register_module()
+@CONV_LAYERS.register_module('deconv3d')
+@UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
+class ConvTranspose3d(nn.ConvTranspose3d):
+
+ def forward(self, x):
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+ out_shape = [x.shape[0], self.out_channels]
+ for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
+ self.padding, self.stride,
+ self.dilation, self.output_padding):
+ out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ if self.training:
+ # produce dummy gradient to avoid DDP warning.
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+ return empty + dummy
+ else:
+ return empty
+
+ return super().forward(x)
+
+
+class MaxPool2d(nn.MaxPool2d):
+
+ def forward(self, x):
+ # PyTorch 1.9 does not support empty tensor inference yet
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
+ out_shape = list(x.shape[:2])
+ for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size),
+ _pair(self.padding), _pair(self.stride),
+ _pair(self.dilation)):
+ o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
+ o = math.ceil(o) if self.ceil_mode else math.floor(o)
+ out_shape.append(o)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ return empty
+
+ return super().forward(x)
+
+
+class MaxPool3d(nn.MaxPool3d):
+
+ def forward(self, x):
+ # PyTorch 1.9 does not support empty tensor inference yet
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
+ out_shape = list(x.shape[:2])
+ for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size),
+ _triple(self.padding),
+ _triple(self.stride),
+ _triple(self.dilation)):
+ o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
+ o = math.ceil(o) if self.ceil_mode else math.floor(o)
+ out_shape.append(o)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ return empty
+
+ return super().forward(x)
+
+
+class Linear(torch.nn.Linear):
+
+ def forward(self, x):
+ # empty tensor forward of Linear layer is supported in Pytorch 1.6
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)):
+ out_shape = [x.shape[0], self.out_features]
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ if self.training:
+ # produce dummy gradient to avoid DDP warning.
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+ return empty + dummy
+ else:
+ return empty
+
+ return super().forward(x)
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/builder.py b/src/custom_mmpkg/custom_mmcv/cnn/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7567316c566bd3aca6d8f65a84b00e9e890948a7
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/builder.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..runner import Sequential
+from ..utils import Registry, build_from_cfg
+
+
+def build_model_from_cfg(cfg, registry, default_args=None):
+ """Build a PyTorch model from config dict(s). Different from
+ ``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.
+
+ Args:
+ cfg (dict, list[dict]): The config of modules, is is either a config
+ dict or a list of config dicts. If cfg is a list, a
+ the built modules will be wrapped with ``nn.Sequential``.
+ registry (:obj:`Registry`): A registry the module belongs to.
+ default_args (dict, optional): Default arguments to build the module.
+ Defaults to None.
+
+ Returns:
+ nn.Module: A built nn module.
+ """
+ if isinstance(cfg, list):
+ modules = [
+ build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
+ ]
+ return Sequential(*modules)
+ else:
+ return build_from_cfg(cfg, registry, default_args)
+
+
+MODELS = Registry('model', build_func=build_model_from_cfg)
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/resnet.py b/src/custom_mmpkg/custom_mmcv/cnn/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cb3ac057ee2d52c46fc94685b5d4e698aad8d5f
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/resnet.py
@@ -0,0 +1,316 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+
+from .utils import constant_init, kaiming_init
+
+
+def conv3x3(in_planes, out_planes, stride=1, dilation=1):
+ """3x3 convolution with padding."""
+ return nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False):
+ super(BasicBlock, self).__init__()
+ assert style in ['pytorch', 'caffe']
+ self.conv1 = conv3x3(inplanes, planes, stride, dilation)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ assert not with_cp
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False):
+ """Bottleneck block.
+
+ If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
+ it is "caffe", the stride-two layer is the first 1x1 conv layer.
+ """
+ super(Bottleneck, self).__init__()
+ assert style in ['pytorch', 'caffe']
+ if style == 'pytorch':
+ conv1_stride = 1
+ conv2_stride = stride
+ else:
+ conv1_stride = stride
+ conv2_stride = 1
+ self.conv1 = nn.Conv2d(
+ inplanes, planes, kernel_size=1, stride=conv1_stride, bias=False)
+ self.conv2 = nn.Conv2d(
+ planes,
+ planes,
+ kernel_size=3,
+ stride=conv2_stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(
+ planes, planes * self.expansion, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ self.with_cp = with_cp
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+def make_res_layer(block,
+ inplanes,
+ planes,
+ blocks,
+ stride=1,
+ dilation=1,
+ style='pytorch',
+ with_cp=False):
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(
+ block(
+ inplanes,
+ planes,
+ stride,
+ dilation,
+ downsample,
+ style=style,
+ with_cp=with_cp))
+ inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(
+ block(inplanes, planes, 1, dilation, style=style, with_cp=with_cp))
+
+ return nn.Sequential(*layers)
+
+
+class ResNet(nn.Module):
+ """ResNet backbone.
+
+ Args:
+ depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
+ num_stages (int): Resnet stages, normally 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+ not freezing any parameters.
+ bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
+ running stats (mean and var).
+ bn_frozen (bool): Whether to freeze weight and bias of BN layers.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ """
+
+ arch_settings = {
+ 18: (BasicBlock, (2, 2, 2, 2)),
+ 34: (BasicBlock, (3, 4, 6, 3)),
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3))
+ }
+
+ def __init__(self,
+ depth,
+ num_stages=4,
+ strides=(1, 2, 2, 2),
+ dilations=(1, 1, 1, 1),
+ out_indices=(0, 1, 2, 3),
+ style='pytorch',
+ frozen_stages=-1,
+ bn_eval=True,
+ bn_frozen=False,
+ with_cp=False):
+ super(ResNet, self).__init__()
+ if depth not in self.arch_settings:
+ raise KeyError(f'invalid depth {depth} for resnet')
+ assert num_stages >= 1 and num_stages <= 4
+ block, stage_blocks = self.arch_settings[depth]
+ stage_blocks = stage_blocks[:num_stages]
+ assert len(strides) == len(dilations) == num_stages
+ assert max(out_indices) < num_stages
+
+ self.out_indices = out_indices
+ self.style = style
+ self.frozen_stages = frozen_stages
+ self.bn_eval = bn_eval
+ self.bn_frozen = bn_frozen
+ self.with_cp = with_cp
+
+ self.inplanes = 64
+ self.conv1 = nn.Conv2d(
+ 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ self.res_layers = []
+ for i, num_blocks in enumerate(stage_blocks):
+ stride = strides[i]
+ dilation = dilations[i]
+ planes = 64 * 2**i
+ res_layer = make_res_layer(
+ block,
+ self.inplanes,
+ planes,
+ num_blocks,
+ stride=stride,
+ dilation=dilation,
+ style=self.style,
+ with_cp=with_cp)
+ self.inplanes = planes * block.expansion
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, res_layer)
+ self.res_layers.append(layer_name)
+
+ self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1)
+
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+ logger = logging.getLogger()
+ from ..runner import load_checkpoint
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, nn.BatchNorm2d):
+ constant_init(m, 1)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ outs = []
+ for i, layer_name in enumerate(self.res_layers):
+ res_layer = getattr(self, layer_name)
+ x = res_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ if len(outs) == 1:
+ return outs[0]
+ else:
+ return tuple(outs)
+
+ def train(self, mode=True):
+ super(ResNet, self).train(mode)
+ if self.bn_eval:
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+ if self.bn_frozen:
+ for params in m.parameters():
+ params.requires_grad = False
+ if mode and self.frozen_stages >= 0:
+ for param in self.conv1.parameters():
+ param.requires_grad = False
+ for param in self.bn1.parameters():
+ param.requires_grad = False
+ self.bn1.eval()
+ self.bn1.weight.requires_grad = False
+ self.bn1.bias.requires_grad = False
+ for i in range(1, self.frozen_stages + 1):
+ mod = getattr(self, f'layer{i}')
+ mod.eval()
+ for param in mod.parameters():
+ param.requires_grad = False
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/utils/__init__.py b/src/custom_mmpkg/custom_mmcv/cnn/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a263e31c1e3977712827ca229bbc04910b4e928e
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/utils/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .flops_counter import get_model_complexity_info
+from .fuse_conv_bn import fuse_conv_bn
+from .sync_bn import revert_sync_batchnorm
+from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
+ KaimingInit, NormalInit, PretrainedInit,
+ TruncNormalInit, UniformInit, XavierInit,
+ bias_init_with_prob, caffe2_xavier_init,
+ constant_init, initialize, kaiming_init, normal_init,
+ trunc_normal_init, uniform_init, xavier_init)
+
+__all__ = [
+ 'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init',
+ 'constant_init', 'kaiming_init', 'normal_init', 'trunc_normal_init',
+ 'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize',
+ 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+ 'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+ 'Caffe2XavierInit', 'revert_sync_batchnorm'
+]
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/utils/flops_counter.py b/src/custom_mmpkg/custom_mmcv/cnn/utils/flops_counter.py
new file mode 100644
index 0000000000000000000000000000000000000000..a445d7a0ef90b371c74476c2b50b7b66eabc6d80
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/utils/flops_counter.py
@@ -0,0 +1,599 @@
+# Modified from flops-counter.pytorch by Vladislav Sovrasov
+# original repo: https://github.com/sovrasov/flops-counter.pytorch
+
+# MIT License
+
+# Copyright (c) 2018 Vladislav Sovrasov
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import sys
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+import custom_mmpkg.custom_mmcv as mmcv
+
+
+def get_model_complexity_info(model,
+ input_shape,
+ print_per_layer_stat=True,
+ as_strings=True,
+ input_constructor=None,
+ flush=False,
+ ost=sys.stdout):
+ """Get complexity information of a model.
+
+ This method can calculate FLOPs and parameter counts of a model with
+ corresponding input shape. It can also print complexity information for
+ each layer in a model.
+
+ Supported layers are listed as below:
+ - Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``.
+ - Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``, ``nn.LeakyReLU``,
+ ``nn.ReLU6``.
+ - Poolings: ``nn.MaxPool1d``, ``nn.MaxPool2d``, ``nn.MaxPool3d``,
+ ``nn.AvgPool1d``, ``nn.AvgPool2d``, ``nn.AvgPool3d``,
+ ``nn.AdaptiveMaxPool1d``, ``nn.AdaptiveMaxPool2d``,
+ ``nn.AdaptiveMaxPool3d``, ``nn.AdaptiveAvgPool1d``,
+ ``nn.AdaptiveAvgPool2d``, ``nn.AdaptiveAvgPool3d``.
+ - BatchNorms: ``nn.BatchNorm1d``, ``nn.BatchNorm2d``,
+ ``nn.BatchNorm3d``, ``nn.GroupNorm``, ``nn.InstanceNorm1d``,
+ ``InstanceNorm2d``, ``InstanceNorm3d``, ``nn.LayerNorm``.
+ - Linear: ``nn.Linear``.
+ - Deconvolution: ``nn.ConvTranspose2d``.
+ - Upsample: ``nn.Upsample``.
+
+ Args:
+ model (nn.Module): The model for complexity calculation.
+ input_shape (tuple): Input shape used for calculation.
+ print_per_layer_stat (bool): Whether to print complexity information
+ for each layer in a model. Default: True.
+ as_strings (bool): Output FLOPs and params counts in a string form.
+ Default: True.
+ input_constructor (None | callable): If specified, it takes a callable
+ method that generates input. otherwise, it will generate a random
+ tensor with input shape to calculate FLOPs. Default: None.
+ flush (bool): same as that in :func:`print`. Default: False.
+ ost (stream): same as ``file`` param in :func:`print`.
+ Default: sys.stdout.
+
+ Returns:
+ tuple[float | str]: If ``as_strings`` is set to True, it will return
+ FLOPs and parameter counts in a string format. otherwise, it will
+ return those in a float number format.
+ """
+ assert type(input_shape) is tuple
+ assert len(input_shape) >= 1
+ assert isinstance(model, nn.Module)
+ flops_model = add_flops_counting_methods(model)
+ flops_model.eval()
+ flops_model.start_flops_count()
+ if input_constructor:
+ input = input_constructor(input_shape)
+ _ = flops_model(**input)
+ else:
+ try:
+ batch = torch.ones(()).new_empty(
+ (1, *input_shape),
+ dtype=next(flops_model.parameters()).dtype,
+ device=next(flops_model.parameters()).device)
+ except StopIteration:
+ # Avoid StopIteration for models which have no parameters,
+ # like `nn.Relu()`, `nn.AvgPool2d`, etc.
+ batch = torch.ones(()).new_empty((1, *input_shape))
+
+ _ = flops_model(batch)
+
+ flops_count, params_count = flops_model.compute_average_flops_cost()
+ if print_per_layer_stat:
+ print_model_with_flops(
+ flops_model, flops_count, params_count, ost=ost, flush=flush)
+ flops_model.stop_flops_count()
+
+ if as_strings:
+ return flops_to_string(flops_count), params_to_string(params_count)
+
+ return flops_count, params_count
+
+
+def flops_to_string(flops, units='GFLOPs', precision=2):
+ """Convert FLOPs number into a string.
+
+ Note that Here we take a multiply-add counts as one FLOP.
+
+ Args:
+ flops (float): FLOPs number to be converted.
+ units (str | None): Converted FLOPs units. Options are None, 'GFLOPs',
+ 'MFLOPs', 'KFLOPs', 'FLOPs'. If set to None, it will automatically
+ choose the most suitable unit for FLOPs. Default: 'GFLOPs'.
+ precision (int): Digit number after the decimal point. Default: 2.
+
+ Returns:
+ str: The converted FLOPs number with units.
+
+ Examples:
+ >>> flops_to_string(1e9)
+ '1.0 GFLOPs'
+ >>> flops_to_string(2e5, 'MFLOPs')
+ '0.2 MFLOPs'
+ >>> flops_to_string(3e-9, None)
+ '3e-09 FLOPs'
+ """
+ if units is None:
+ if flops // 10**9 > 0:
+ return str(round(flops / 10.**9, precision)) + ' GFLOPs'
+ elif flops // 10**6 > 0:
+ return str(round(flops / 10.**6, precision)) + ' MFLOPs'
+ elif flops // 10**3 > 0:
+ return str(round(flops / 10.**3, precision)) + ' KFLOPs'
+ else:
+ return str(flops) + ' FLOPs'
+ else:
+ if units == 'GFLOPs':
+ return str(round(flops / 10.**9, precision)) + ' ' + units
+ elif units == 'MFLOPs':
+ return str(round(flops / 10.**6, precision)) + ' ' + units
+ elif units == 'KFLOPs':
+ return str(round(flops / 10.**3, precision)) + ' ' + units
+ else:
+ return str(flops) + ' FLOPs'
+
+
+def params_to_string(num_params, units=None, precision=2):
+ """Convert parameter number into a string.
+
+ Args:
+ num_params (float): Parameter number to be converted.
+ units (str | None): Converted FLOPs units. Options are None, 'M',
+ 'K' and ''. If set to None, it will automatically choose the most
+ suitable unit for Parameter number. Default: None.
+ precision (int): Digit number after the decimal point. Default: 2.
+
+ Returns:
+ str: The converted parameter number with units.
+
+ Examples:
+ >>> params_to_string(1e9)
+ '1000.0 M'
+ >>> params_to_string(2e5)
+ '200.0 k'
+ >>> params_to_string(3e-9)
+ '3e-09'
+ """
+ if units is None:
+ if num_params // 10**6 > 0:
+ return str(round(num_params / 10**6, precision)) + ' M'
+ elif num_params // 10**3:
+ return str(round(num_params / 10**3, precision)) + ' k'
+ else:
+ return str(num_params)
+ else:
+ if units == 'M':
+ return str(round(num_params / 10.**6, precision)) + ' ' + units
+ elif units == 'K':
+ return str(round(num_params / 10.**3, precision)) + ' ' + units
+ else:
+ return str(num_params)
+
+
+def print_model_with_flops(model,
+ total_flops,
+ total_params,
+ units='GFLOPs',
+ precision=3,
+ ost=sys.stdout,
+ flush=False):
+ """Print a model with FLOPs for each layer.
+
+ Args:
+ model (nn.Module): The model to be printed.
+ total_flops (float): Total FLOPs of the model.
+ total_params (float): Total parameter counts of the model.
+ units (str | None): Converted FLOPs units. Default: 'GFLOPs'.
+ precision (int): Digit number after the decimal point. Default: 3.
+ ost (stream): same as `file` param in :func:`print`.
+ Default: sys.stdout.
+ flush (bool): same as that in :func:`print`. Default: False.
+
+ Example:
+ >>> class ExampleModel(nn.Module):
+
+ >>> def __init__(self):
+ >>> super().__init__()
+ >>> self.conv1 = nn.Conv2d(3, 8, 3)
+ >>> self.conv2 = nn.Conv2d(8, 256, 3)
+ >>> self.conv3 = nn.Conv2d(256, 8, 3)
+ >>> self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
+ >>> self.flatten = nn.Flatten()
+ >>> self.fc = nn.Linear(8, 1)
+
+ >>> def forward(self, x):
+ >>> x = self.conv1(x)
+ >>> x = self.conv2(x)
+ >>> x = self.conv3(x)
+ >>> x = self.avg_pool(x)
+ >>> x = self.flatten(x)
+ >>> x = self.fc(x)
+ >>> return x
+
+ >>> model = ExampleModel()
+ >>> x = (3, 16, 16)
+ to print the complexity information state for each layer, you can use
+ >>> get_model_complexity_info(model, x)
+ or directly use
+ >>> print_model_with_flops(model, 4579784.0, 37361)
+ ExampleModel(
+ 0.037 M, 100.000% Params, 0.005 GFLOPs, 100.000% FLOPs,
+ (conv1): Conv2d(0.0 M, 0.600% Params, 0.0 GFLOPs, 0.959% FLOPs, 3, 8, kernel_size=(3, 3), stride=(1, 1)) # noqa: E501
+ (conv2): Conv2d(0.019 M, 50.020% Params, 0.003 GFLOPs, 58.760% FLOPs, 8, 256, kernel_size=(3, 3), stride=(1, 1))
+ (conv3): Conv2d(0.018 M, 49.356% Params, 0.002 GFLOPs, 40.264% FLOPs, 256, 8, kernel_size=(3, 3), stride=(1, 1))
+ (avg_pool): AdaptiveAvgPool2d(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.017% FLOPs, output_size=(1, 1))
+ (flatten): Flatten(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.000% FLOPs, )
+ (fc): Linear(0.0 M, 0.024% Params, 0.0 GFLOPs, 0.000% FLOPs, in_features=8, out_features=1, bias=True)
+ )
+ """
+
+ def accumulate_params(self):
+ if is_supported_instance(self):
+ return self.__params__
+ else:
+ sum = 0
+ for m in self.children():
+ sum += m.accumulate_params()
+ return sum
+
+ def accumulate_flops(self):
+ if is_supported_instance(self):
+ return self.__flops__ / model.__batch_counter__
+ else:
+ sum = 0
+ for m in self.children():
+ sum += m.accumulate_flops()
+ return sum
+
+ def flops_repr(self):
+ accumulated_num_params = self.accumulate_params()
+ accumulated_flops_cost = self.accumulate_flops()
+ return ', '.join([
+ params_to_string(
+ accumulated_num_params, units='M', precision=precision),
+ '{:.3%} Params'.format(accumulated_num_params / total_params),
+ flops_to_string(
+ accumulated_flops_cost, units=units, precision=precision),
+ '{:.3%} FLOPs'.format(accumulated_flops_cost / total_flops),
+ self.original_extra_repr()
+ ])
+
+ def add_extra_repr(m):
+ m.accumulate_flops = accumulate_flops.__get__(m)
+ m.accumulate_params = accumulate_params.__get__(m)
+ flops_extra_repr = flops_repr.__get__(m)
+ if m.extra_repr != flops_extra_repr:
+ m.original_extra_repr = m.extra_repr
+ m.extra_repr = flops_extra_repr
+ assert m.extra_repr != m.original_extra_repr
+
+ def del_extra_repr(m):
+ if hasattr(m, 'original_extra_repr'):
+ m.extra_repr = m.original_extra_repr
+ del m.original_extra_repr
+ if hasattr(m, 'accumulate_flops'):
+ del m.accumulate_flops
+
+ model.apply(add_extra_repr)
+ print(model, file=ost, flush=flush)
+ model.apply(del_extra_repr)
+
+
+def get_model_parameters_number(model):
+ """Calculate parameter number of a model.
+
+ Args:
+ model (nn.module): The model for parameter number calculation.
+
+ Returns:
+ float: Parameter number of the model.
+ """
+ num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ return num_params
+
+
+def add_flops_counting_methods(net_main_module):
+ # adding additional methods to the existing module object,
+ # this is done this way so that each function has access to self object
+ net_main_module.start_flops_count = start_flops_count.__get__(
+ net_main_module)
+ net_main_module.stop_flops_count = stop_flops_count.__get__(
+ net_main_module)
+ net_main_module.reset_flops_count = reset_flops_count.__get__(
+ net_main_module)
+ net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__( # noqa: E501
+ net_main_module)
+
+ net_main_module.reset_flops_count()
+
+ return net_main_module
+
+
+def compute_average_flops_cost(self):
+ """Compute average FLOPs cost.
+
+ A method to compute average FLOPs cost, which will be available after
+ `add_flops_counting_methods()` is called on a desired net object.
+
+ Returns:
+ float: Current mean flops consumption per image.
+ """
+ batches_count = self.__batch_counter__
+ flops_sum = 0
+ for module in self.modules():
+ if is_supported_instance(module):
+ flops_sum += module.__flops__
+ params_sum = get_model_parameters_number(self)
+ return flops_sum / batches_count, params_sum
+
+
+def start_flops_count(self):
+ """Activate the computation of mean flops consumption per image.
+
+ A method to activate the computation of mean flops consumption per image.
+ which will be available after ``add_flops_counting_methods()`` is called on
+ a desired net object. It should be called before running the network.
+ """
+ add_batch_counter_hook_function(self)
+
+ def add_flops_counter_hook_function(module):
+ if is_supported_instance(module):
+ if hasattr(module, '__flops_handle__'):
+ return
+
+ else:
+ handle = module.register_forward_hook(
+ get_modules_mapping()[type(module)])
+
+ module.__flops_handle__ = handle
+
+ self.apply(partial(add_flops_counter_hook_function))
+
+
+def stop_flops_count(self):
+ """Stop computing the mean flops consumption per image.
+
+ A method to stop computing the mean flops consumption per image, which will
+ be available after ``add_flops_counting_methods()`` is called on a desired
+ net object. It can be called to pause the computation whenever.
+ """
+ remove_batch_counter_hook_function(self)
+ self.apply(remove_flops_counter_hook_function)
+
+
+def reset_flops_count(self):
+ """Reset statistics computed so far.
+
+ A method to Reset computed statistics, which will be available after
+ `add_flops_counting_methods()` is called on a desired net object.
+ """
+ add_batch_counter_variables_or_reset(self)
+ self.apply(add_flops_counter_variable_or_reset)
+
+
+# ---- Internal functions
+def empty_flops_counter_hook(module, input, output):
+ module.__flops__ += 0
+
+
+def upsample_flops_counter_hook(module, input, output):
+ output_size = output[0]
+ batch_size = output_size.shape[0]
+ output_elements_count = batch_size
+ for val in output_size.shape[1:]:
+ output_elements_count *= val
+ module.__flops__ += int(output_elements_count)
+
+
+def relu_flops_counter_hook(module, input, output):
+ active_elements_count = output.numel()
+ module.__flops__ += int(active_elements_count)
+
+
+def linear_flops_counter_hook(module, input, output):
+ input = input[0]
+ output_last_dim = output.shape[
+ -1] # pytorch checks dimensions, so here we don't care much
+ module.__flops__ += int(np.prod(input.shape) * output_last_dim)
+
+
+def pool_flops_counter_hook(module, input, output):
+ input = input[0]
+ module.__flops__ += int(np.prod(input.shape))
+
+
+def norm_flops_counter_hook(module, input, output):
+ input = input[0]
+
+ batch_flops = np.prod(input.shape)
+ if (getattr(module, 'affine', False)
+ or getattr(module, 'elementwise_affine', False)):
+ batch_flops *= 2
+ module.__flops__ += int(batch_flops)
+
+
+def deconv_flops_counter_hook(conv_module, input, output):
+ # Can have multiple inputs, getting the first one
+ input = input[0]
+
+ batch_size = input.shape[0]
+ input_height, input_width = input.shape[2:]
+
+ kernel_height, kernel_width = conv_module.kernel_size
+ in_channels = conv_module.in_channels
+ out_channels = conv_module.out_channels
+ groups = conv_module.groups
+
+ filters_per_channel = out_channels // groups
+ conv_per_position_flops = (
+ kernel_height * kernel_width * in_channels * filters_per_channel)
+
+ active_elements_count = batch_size * input_height * input_width
+ overall_conv_flops = conv_per_position_flops * active_elements_count
+ bias_flops = 0
+ if conv_module.bias is not None:
+ output_height, output_width = output.shape[2:]
+ bias_flops = out_channels * batch_size * output_height * output_height
+ overall_flops = overall_conv_flops + bias_flops
+
+ conv_module.__flops__ += int(overall_flops)
+
+
+def conv_flops_counter_hook(conv_module, input, output):
+ # Can have multiple inputs, getting the first one
+ input = input[0]
+
+ batch_size = input.shape[0]
+ output_dims = list(output.shape[2:])
+
+ kernel_dims = list(conv_module.kernel_size)
+ in_channels = conv_module.in_channels
+ out_channels = conv_module.out_channels
+ groups = conv_module.groups
+
+ filters_per_channel = out_channels // groups
+ conv_per_position_flops = int(
+ np.prod(kernel_dims)) * in_channels * filters_per_channel
+
+ active_elements_count = batch_size * int(np.prod(output_dims))
+
+ overall_conv_flops = conv_per_position_flops * active_elements_count
+
+ bias_flops = 0
+
+ if conv_module.bias is not None:
+
+ bias_flops = out_channels * active_elements_count
+
+ overall_flops = overall_conv_flops + bias_flops
+
+ conv_module.__flops__ += int(overall_flops)
+
+
+def batch_counter_hook(module, input, output):
+ batch_size = 1
+ if len(input) > 0:
+ # Can have multiple inputs, getting the first one
+ input = input[0]
+ batch_size = len(input)
+ else:
+ pass
+ print('Warning! No positional inputs found for a module, '
+ 'assuming batch size is 1.')
+ module.__batch_counter__ += batch_size
+
+
+def add_batch_counter_variables_or_reset(module):
+
+ module.__batch_counter__ = 0
+
+
+def add_batch_counter_hook_function(module):
+ if hasattr(module, '__batch_counter_handle__'):
+ return
+
+ handle = module.register_forward_hook(batch_counter_hook)
+ module.__batch_counter_handle__ = handle
+
+
+def remove_batch_counter_hook_function(module):
+ if hasattr(module, '__batch_counter_handle__'):
+ module.__batch_counter_handle__.remove()
+ del module.__batch_counter_handle__
+
+
+def add_flops_counter_variable_or_reset(module):
+ if is_supported_instance(module):
+ if hasattr(module, '__flops__') or hasattr(module, '__params__'):
+ print('Warning: variables __flops__ or __params__ are already '
+ 'defined for the module' + type(module).__name__ +
+ ' ptflops can affect your code!')
+ module.__flops__ = 0
+ module.__params__ = get_model_parameters_number(module)
+
+
+def is_supported_instance(module):
+ if type(module) in get_modules_mapping():
+ return True
+ return False
+
+
+def remove_flops_counter_hook_function(module):
+ if is_supported_instance(module):
+ if hasattr(module, '__flops_handle__'):
+ module.__flops_handle__.remove()
+ del module.__flops_handle__
+
+
+def get_modules_mapping():
+ return {
+ # convolutions
+ nn.Conv1d: conv_flops_counter_hook,
+ nn.Conv2d: conv_flops_counter_hook,
+ mmcv.cnn.bricks.Conv2d: conv_flops_counter_hook,
+ nn.Conv3d: conv_flops_counter_hook,
+ mmcv.cnn.bricks.Conv3d: conv_flops_counter_hook,
+ # activations
+ nn.ReLU: relu_flops_counter_hook,
+ nn.PReLU: relu_flops_counter_hook,
+ nn.ELU: relu_flops_counter_hook,
+ nn.LeakyReLU: relu_flops_counter_hook,
+ nn.ReLU6: relu_flops_counter_hook,
+ # poolings
+ nn.MaxPool1d: pool_flops_counter_hook,
+ nn.AvgPool1d: pool_flops_counter_hook,
+ nn.AvgPool2d: pool_flops_counter_hook,
+ nn.MaxPool2d: pool_flops_counter_hook,
+ mmcv.cnn.bricks.MaxPool2d: pool_flops_counter_hook,
+ nn.MaxPool3d: pool_flops_counter_hook,
+ mmcv.cnn.bricks.MaxPool3d: pool_flops_counter_hook,
+ nn.AvgPool3d: pool_flops_counter_hook,
+ nn.AdaptiveMaxPool1d: pool_flops_counter_hook,
+ nn.AdaptiveAvgPool1d: pool_flops_counter_hook,
+ nn.AdaptiveMaxPool2d: pool_flops_counter_hook,
+ nn.AdaptiveAvgPool2d: pool_flops_counter_hook,
+ nn.AdaptiveMaxPool3d: pool_flops_counter_hook,
+ nn.AdaptiveAvgPool3d: pool_flops_counter_hook,
+ # normalizations
+ nn.BatchNorm1d: norm_flops_counter_hook,
+ nn.BatchNorm2d: norm_flops_counter_hook,
+ nn.BatchNorm3d: norm_flops_counter_hook,
+ nn.GroupNorm: norm_flops_counter_hook,
+ nn.InstanceNorm1d: norm_flops_counter_hook,
+ nn.InstanceNorm2d: norm_flops_counter_hook,
+ nn.InstanceNorm3d: norm_flops_counter_hook,
+ nn.LayerNorm: norm_flops_counter_hook,
+ # FC
+ nn.Linear: linear_flops_counter_hook,
+ mmcv.cnn.bricks.Linear: linear_flops_counter_hook,
+ # Upscale
+ nn.Upsample: upsample_flops_counter_hook,
+ # Deconvolution
+ nn.ConvTranspose2d: deconv_flops_counter_hook,
+ mmcv.cnn.bricks.ConvTranspose2d: deconv_flops_counter_hook,
+ }
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/utils/fuse_conv_bn.py b/src/custom_mmpkg/custom_mmcv/cnn/utils/fuse_conv_bn.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb7076f80bf37f7931185bf0293ffcc1ce19c8ef
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/utils/fuse_conv_bn.py
@@ -0,0 +1,59 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+
+def _fuse_conv_bn(conv, bn):
+ """Fuse conv and bn into one module.
+
+ Args:
+ conv (nn.Module): Conv to be fused.
+ bn (nn.Module): BN to be fused.
+
+ Returns:
+ nn.Module: Fused module.
+ """
+ conv_w = conv.weight
+ conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
+ bn.running_mean)
+
+ factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
+ conv.weight = nn.Parameter(conv_w *
+ factor.reshape([conv.out_channels, 1, 1, 1]))
+ conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
+ return conv
+
+
+def fuse_conv_bn(module):
+ """Recursively fuse conv and bn in a module.
+
+ During inference, the functionary of batch norm layers is turned off
+ but only the mean and var alone channels are used, which exposes the
+ chance to fuse it with the preceding conv layers to save computations and
+ simplify network structures.
+
+ Args:
+ module (nn.Module): Module to be fused.
+
+ Returns:
+ nn.Module: Fused module.
+ """
+ last_conv = None
+ last_conv_name = None
+
+ for name, child in module.named_children():
+ if isinstance(child,
+ (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
+ if last_conv is None: # only fuse BN that is after Conv
+ continue
+ fused_conv = _fuse_conv_bn(last_conv, child)
+ module._modules[last_conv_name] = fused_conv
+ # To reduce changes, set BN as Identity instead of deleting it.
+ module._modules[name] = nn.Identity()
+ last_conv = None
+ elif isinstance(child, nn.Conv2d):
+ last_conv = child
+ last_conv_name = name
+ else:
+ fuse_conv_bn(child)
+ return module
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/utils/sync_bn.py b/src/custom_mmpkg/custom_mmcv/cnn/utils/sync_bn.py
new file mode 100644
index 0000000000000000000000000000000000000000..f75291daab5cfbf367621cef62b0067aed9fbd0d
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/utils/sync_bn.py
@@ -0,0 +1,59 @@
+import torch
+
+import custom_mmpkg.custom_mmcv as mmcv
+
+
+class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
+ """A general BatchNorm layer without input dimension check.
+
+ Reproduced from @kapily's work:
+ (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+ The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc
+ is `_check_input_dim` that is designed for tensor sanity checks.
+ The check has been bypassed in this class for the convenience of converting
+ SyncBatchNorm.
+ """
+
+ def _check_input_dim(self, input):
+ return
+
+
+def revert_sync_batchnorm(module):
+ """Helper function to convert all `SyncBatchNorm` (SyncBN) and
+ `mmcv.ops.sync_bn.SyncBatchNorm`(MMSyncBN) layers in the model to
+ `BatchNormXd` layers.
+
+ Adapted from @kapily's work:
+ (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+
+ Args:
+ module (nn.Module): The module containing `SyncBatchNorm` layers.
+
+ Returns:
+ module_output: The converted module with `BatchNormXd` layers.
+ """
+ module_output = module
+ module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]
+ if hasattr(mmcv, 'ops'):
+ module_checklist.append(mmcv.ops.SyncBatchNorm)
+ if isinstance(module, tuple(module_checklist)):
+ module_output = _BatchNormXd(module.num_features, module.eps,
+ module.momentum, module.affine,
+ module.track_running_stats)
+ if module.affine:
+ # no_grad() may not be needed here but
+ # just to be consistent with `convert_sync_batchnorm()`
+ with torch.no_grad():
+ module_output.weight = module.weight
+ module_output.bias = module.bias
+ module_output.running_mean = module.running_mean
+ module_output.running_var = module.running_var
+ module_output.num_batches_tracked = module.num_batches_tracked
+ module_output.training = module.training
+ # qconfig exists in quantized models
+ if hasattr(module, 'qconfig'):
+ module_output.qconfig = module.qconfig
+ for name, child in module.named_children():
+ module_output.add_module(name, revert_sync_batchnorm(child))
+ del module
+ return module_output
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/utils/weight_init.py b/src/custom_mmpkg/custom_mmcv/cnn/utils/weight_init.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a5bb1755d2269829c113b98026aa0310a3d70cb
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/utils/weight_init.py
@@ -0,0 +1,684 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import math
+import warnings
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from custom_mmpkg.custom_mmcv.utils import Registry, build_from_cfg, get_logger, print_log
+
+INITIALIZERS = Registry('initializer')
+
+
+def update_init_info(module, init_info):
+ """Update the `_params_init_info` in the module if the value of parameters
+ are changed.
+
+ Args:
+ module (obj:`nn.Module`): The module of PyTorch with a user-defined
+ attribute `_params_init_info` which records the initialization
+ information.
+ init_info (str): The string that describes the initialization.
+ """
+ assert hasattr(
+ module,
+ '_params_init_info'), f'Can not find `_params_init_info` in {module}'
+ for name, param in module.named_parameters():
+
+ assert param in module._params_init_info, (
+ f'Find a new :obj:`Parameter` '
+ f'named `{name}` during executing the '
+ f'`init_weights` of '
+ f'`{module.__class__.__name__}`. '
+ f'Please do not add or '
+ f'replace parameters during executing '
+ f'the `init_weights`. ')
+
+ # The parameter has been changed during executing the
+ # `init_weights` of module
+ mean_value = param.data.mean()
+ if module._params_init_info[param]['tmp_mean_value'] != mean_value:
+ module._params_init_info[param]['init_info'] = init_info
+ module._params_init_info[param]['tmp_mean_value'] = mean_value
+
+
+def constant_init(module, val, bias=0):
+ if hasattr(module, 'weight') and module.weight is not None:
+ nn.init.constant_(module.weight, val)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def xavier_init(module, gain=1, bias=0, distribution='normal'):
+ assert distribution in ['uniform', 'normal']
+ if hasattr(module, 'weight') and module.weight is not None:
+ if distribution == 'uniform':
+ nn.init.xavier_uniform_(module.weight, gain=gain)
+ else:
+ nn.init.xavier_normal_(module.weight, gain=gain)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def normal_init(module, mean=0, std=1, bias=0):
+ if hasattr(module, 'weight') and module.weight is not None:
+ nn.init.normal_(module.weight, mean, std)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def trunc_normal_init(module: nn.Module,
+ mean: float = 0,
+ std: float = 1,
+ a: float = -2,
+ b: float = 2,
+ bias: float = 0) -> None:
+ if hasattr(module, 'weight') and module.weight is not None:
+ trunc_normal_(module.weight, mean, std, a, b) # type: ignore
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias) # type: ignore
+
+
+def uniform_init(module, a=0, b=1, bias=0):
+ if hasattr(module, 'weight') and module.weight is not None:
+ nn.init.uniform_(module.weight, a, b)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def kaiming_init(module,
+ a=0,
+ mode='fan_out',
+ nonlinearity='relu',
+ bias=0,
+ distribution='normal'):
+ assert distribution in ['uniform', 'normal']
+ if hasattr(module, 'weight') and module.weight is not None:
+ if distribution == 'uniform':
+ nn.init.kaiming_uniform_(
+ module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
+ else:
+ nn.init.kaiming_normal_(
+ module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def caffe2_xavier_init(module, bias=0):
+ # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
+ # Acknowledgment to FAIR's internal code
+ kaiming_init(
+ module,
+ a=1,
+ mode='fan_in',
+ nonlinearity='leaky_relu',
+ bias=bias,
+ distribution='uniform')
+
+
+def bias_init_with_prob(prior_prob):
+ """initialize conv/fc bias value according to a given probability value."""
+ bias_init = float(-np.log((1 - prior_prob) / prior_prob))
+ return bias_init
+
+
+def _get_bases_name(m):
+ return [b.__name__ for b in m.__class__.__bases__]
+
+
+class BaseInit(object):
+
+ def __init__(self, *, bias=0, bias_prob=None, layer=None):
+ self.wholemodule = False
+ if not isinstance(bias, (int, float)):
+ raise TypeError(f'bias must be a number, but got a {type(bias)}')
+
+ if bias_prob is not None:
+ if not isinstance(bias_prob, float):
+ raise TypeError(f'bias_prob type must be float, \
+ but got {type(bias_prob)}')
+
+ if layer is not None:
+ if not isinstance(layer, (str, list)):
+ raise TypeError(f'layer must be a str or a list of str, \
+ but got a {type(layer)}')
+ else:
+ layer = []
+
+ if bias_prob is not None:
+ self.bias = bias_init_with_prob(bias_prob)
+ else:
+ self.bias = bias
+ self.layer = [layer] if isinstance(layer, str) else layer
+
+ def _get_init_info(self):
+ info = f'{self.__class__.__name__}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='Constant')
+class ConstantInit(BaseInit):
+ """Initialize module parameters with constant values.
+
+ Args:
+ val (int | float): the value to fill the weights in the module with
+ bias (int | float): the value to fill the bias. Defaults to 0.
+ bias_prob (float, optional): the probability for bias initialization.
+ Defaults to None.
+ layer (str | list[str], optional): the layer will be initialized.
+ Defaults to None.
+ """
+
+ def __init__(self, val, **kwargs):
+ super().__init__(**kwargs)
+ self.val = val
+
+ def __call__(self, module):
+
+ def init(m):
+ if self.wholemodule:
+ constant_init(m, self.val, self.bias)
+ else:
+ layername = m.__class__.__name__
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ constant_init(m, self.val, self.bias)
+
+ module.apply(init)
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self):
+ info = f'{self.__class__.__name__}: val={self.val}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='Xavier')
+class XavierInit(BaseInit):
+ r"""Initialize module parameters with values according to the method
+ described in `Understanding the difficulty of training deep feedforward
+ neural networks - Glorot, X. & Bengio, Y. (2010).
+ `_
+
+ Args:
+ gain (int | float): an optional scaling factor. Defaults to 1.
+ bias (int | float): the value to fill the bias. Defaults to 0.
+ bias_prob (float, optional): the probability for bias initialization.
+ Defaults to None.
+ distribution (str): distribution either be ``'normal'``
+ or ``'uniform'``. Defaults to ``'normal'``.
+ layer (str | list[str], optional): the layer will be initialized.
+ Defaults to None.
+ """
+
+ def __init__(self, gain=1, distribution='normal', **kwargs):
+ super().__init__(**kwargs)
+ self.gain = gain
+ self.distribution = distribution
+
+ def __call__(self, module):
+
+ def init(m):
+ if self.wholemodule:
+ xavier_init(m, self.gain, self.bias, self.distribution)
+ else:
+ layername = m.__class__.__name__
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ xavier_init(m, self.gain, self.bias, self.distribution)
+
+ module.apply(init)
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self):
+ info = f'{self.__class__.__name__}: gain={self.gain}, ' \
+ f'distribution={self.distribution}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='Normal')
+class NormalInit(BaseInit):
+ r"""Initialize module parameters with the values drawn from the normal
+ distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
+
+ Args:
+ mean (int | float):the mean of the normal distribution. Defaults to 0.
+ std (int | float): the standard deviation of the normal distribution.
+ Defaults to 1.
+ bias (int | float): the value to fill the bias. Defaults to 0.
+ bias_prob (float, optional): the probability for bias initialization.
+ Defaults to None.
+ layer (str | list[str], optional): the layer will be initialized.
+ Defaults to None.
+
+ """
+
+ def __init__(self, mean=0, std=1, **kwargs):
+ super().__init__(**kwargs)
+ self.mean = mean
+ self.std = std
+
+ def __call__(self, module):
+
+ def init(m):
+ if self.wholemodule:
+ normal_init(m, self.mean, self.std, self.bias)
+ else:
+ layername = m.__class__.__name__
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ normal_init(m, self.mean, self.std, self.bias)
+
+ module.apply(init)
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self):
+ info = f'{self.__class__.__name__}: mean={self.mean},' \
+ f' std={self.std}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='TruncNormal')
+class TruncNormalInit(BaseInit):
+ r"""Initialize module parameters with the values drawn from the normal
+ distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values
+ outside :math:`[a, b]`.
+
+ Args:
+ mean (float): the mean of the normal distribution. Defaults to 0.
+ std (float): the standard deviation of the normal distribution.
+ Defaults to 1.
+ a (float): The minimum cutoff value.
+ b ( float): The maximum cutoff value.
+ bias (float): the value to fill the bias. Defaults to 0.
+ bias_prob (float, optional): the probability for bias initialization.
+ Defaults to None.
+ layer (str | list[str], optional): the layer will be initialized.
+ Defaults to None.
+
+ """
+
+ def __init__(self,
+ mean: float = 0,
+ std: float = 1,
+ a: float = -2,
+ b: float = 2,
+ **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.mean = mean
+ self.std = std
+ self.a = a
+ self.b = b
+
+ def __call__(self, module: nn.Module) -> None:
+
+ def init(m):
+ if self.wholemodule:
+ trunc_normal_init(m, self.mean, self.std, self.a, self.b,
+ self.bias)
+ else:
+ layername = m.__class__.__name__
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ trunc_normal_init(m, self.mean, self.std, self.a, self.b,
+ self.bias)
+
+ module.apply(init)
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self):
+ info = f'{self.__class__.__name__}: a={self.a}, b={self.b},' \
+ f' mean={self.mean}, std={self.std}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='Uniform')
+class UniformInit(BaseInit):
+ r"""Initialize module parameters with values drawn from the uniform
+ distribution :math:`\mathcal{U}(a, b)`.
+
+ Args:
+ a (int | float): the lower bound of the uniform distribution.
+ Defaults to 0.
+ b (int | float): the upper bound of the uniform distribution.
+ Defaults to 1.
+ bias (int | float): the value to fill the bias. Defaults to 0.
+ bias_prob (float, optional): the probability for bias initialization.
+ Defaults to None.
+ layer (str | list[str], optional): the layer will be initialized.
+ Defaults to None.
+ """
+
+ def __init__(self, a=0, b=1, **kwargs):
+ super().__init__(**kwargs)
+ self.a = a
+ self.b = b
+
+ def __call__(self, module):
+
+ def init(m):
+ if self.wholemodule:
+ uniform_init(m, self.a, self.b, self.bias)
+ else:
+ layername = m.__class__.__name__
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ uniform_init(m, self.a, self.b, self.bias)
+
+ module.apply(init)
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self):
+ info = f'{self.__class__.__name__}: a={self.a},' \
+ f' b={self.b}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='Kaiming')
+class KaimingInit(BaseInit):
+ r"""Initialize module parameters with the values according to the method
+ described in `Delving deep into rectifiers: Surpassing human-level
+ performance on ImageNet classification - He, K. et al. (2015).
+ `_
+
+ Args:
+ a (int | float): the negative slope of the rectifier used after this
+ layer (only used with ``'leaky_relu'``). Defaults to 0.
+ mode (str): either ``'fan_in'`` or ``'fan_out'``. Choosing
+ ``'fan_in'`` preserves the magnitude of the variance of the weights
+ in the forward pass. Choosing ``'fan_out'`` preserves the
+ magnitudes in the backwards pass. Defaults to ``'fan_out'``.
+ nonlinearity (str): the non-linear function (`nn.functional` name),
+ recommended to use only with ``'relu'`` or ``'leaky_relu'`` .
+ Defaults to 'relu'.
+ bias (int | float): the value to fill the bias. Defaults to 0.
+ bias_prob (float, optional): the probability for bias initialization.
+ Defaults to None.
+ distribution (str): distribution either be ``'normal'`` or
+ ``'uniform'``. Defaults to ``'normal'``.
+ layer (str | list[str], optional): the layer will be initialized.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ a=0,
+ mode='fan_out',
+ nonlinearity='relu',
+ distribution='normal',
+ **kwargs):
+ super().__init__(**kwargs)
+ self.a = a
+ self.mode = mode
+ self.nonlinearity = nonlinearity
+ self.distribution = distribution
+
+ def __call__(self, module):
+
+ def init(m):
+ if self.wholemodule:
+ kaiming_init(m, self.a, self.mode, self.nonlinearity,
+ self.bias, self.distribution)
+ else:
+ layername = m.__class__.__name__
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ kaiming_init(m, self.a, self.mode, self.nonlinearity,
+ self.bias, self.distribution)
+
+ module.apply(init)
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self):
+ info = f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, ' \
+ f'nonlinearity={self.nonlinearity}, ' \
+ f'distribution ={self.distribution}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='Caffe2Xavier')
+class Caffe2XavierInit(KaimingInit):
+ # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
+ # Acknowledgment to FAIR's internal code
+ def __init__(self, **kwargs):
+ super().__init__(
+ a=1,
+ mode='fan_in',
+ nonlinearity='leaky_relu',
+ distribution='uniform',
+ **kwargs)
+
+ def __call__(self, module):
+ super().__call__(module)
+
+
+@INITIALIZERS.register_module(name='Pretrained')
+class PretrainedInit(object):
+ """Initialize module by loading a pretrained model.
+
+ Args:
+ checkpoint (str): the checkpoint file of the pretrained model should
+ be load.
+ prefix (str, optional): the prefix of a sub-module in the pretrained
+ model. it is for loading a part of the pretrained model to
+ initialize. For example, if we would like to only load the
+ backbone of a detector model, we can set ``prefix='backbone.'``.
+ Defaults to None.
+ map_location (str): map tensors into proper locations.
+ """
+
+ def __init__(self, checkpoint, prefix=None, map_location=None):
+ self.checkpoint = checkpoint
+ self.prefix = prefix
+ self.map_location = map_location
+
+ def __call__(self, module):
+ from custom_mmpkg.custom_mmcv.runner import (_load_checkpoint_with_prefix, load_checkpoint,
+ load_state_dict)
+ logger = get_logger('mmcv')
+ if self.prefix is None:
+ print_log(f'load model from: {self.checkpoint}', logger=logger)
+ load_checkpoint(
+ module,
+ self.checkpoint,
+ map_location=self.map_location,
+ strict=False,
+ logger=logger)
+ else:
+ print_log(
+ f'load {self.prefix} in model from: {self.checkpoint}',
+ logger=logger)
+ state_dict = _load_checkpoint_with_prefix(
+ self.prefix, self.checkpoint, map_location=self.map_location)
+ load_state_dict(module, state_dict, strict=False, logger=logger)
+
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self):
+ info = f'{self.__class__.__name__}: load from {self.checkpoint}'
+ return info
+
+
+def _initialize(module, cfg, wholemodule=False):
+ func = build_from_cfg(cfg, INITIALIZERS)
+ # wholemodule flag is for override mode, there is no layer key in override
+ # and initializer will give init values for the whole module with the name
+ # in override.
+ func.wholemodule = wholemodule
+ func(module)
+
+
+def _initialize_override(module, override, cfg):
+ if not isinstance(override, (dict, list)):
+ raise TypeError(f'override must be a dict or a list of dict, \
+ but got {type(override)}')
+
+ override = [override] if isinstance(override, dict) else override
+
+ for override_ in override:
+
+ cp_override = copy.deepcopy(override_)
+ name = cp_override.pop('name', None)
+ if name is None:
+ raise ValueError('`override` must contain the key "name",'
+ f'but got {cp_override}')
+ # if override only has name key, it means use args in init_cfg
+ if not cp_override:
+ cp_override.update(cfg)
+ # if override has name key and other args except type key, it will
+ # raise error
+ elif 'type' not in cp_override.keys():
+ raise ValueError(
+ f'`override` need "type" key, but got {cp_override}')
+
+ if hasattr(module, name):
+ _initialize(getattr(module, name), cp_override, wholemodule=True)
+ else:
+ raise RuntimeError(f'module did not have attribute {name}, '
+ f'but init_cfg is {cp_override}.')
+
+
+def initialize(module, init_cfg):
+ """Initialize a module.
+
+ Args:
+ module (``torch.nn.Module``): the module will be initialized.
+ init_cfg (dict | list[dict]): initialization configuration dict to
+ define initializer. OpenMMLab has implemented 6 initializers
+ including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
+ ``Kaiming``, and ``Pretrained``.
+ Example:
+ >>> module = nn.Linear(2, 3, bias=True)
+ >>> init_cfg = dict(type='Constant', layer='Linear', val =1 , bias =2)
+ >>> initialize(module, init_cfg)
+
+ >>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1,2))
+ >>> # define key ``'layer'`` for initializing layer with different
+ >>> # configuration
+ >>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1),
+ dict(type='Constant', layer='Linear', val=2)]
+ >>> initialize(module, init_cfg)
+
+ >>> # define key``'override'`` to initialize some specific part in
+ >>> # module
+ >>> class FooNet(nn.Module):
+ >>> def __init__(self):
+ >>> super().__init__()
+ >>> self.feat = nn.Conv2d(3, 16, 3)
+ >>> self.reg = nn.Conv2d(16, 10, 3)
+ >>> self.cls = nn.Conv2d(16, 5, 3)
+ >>> model = FooNet()
+ >>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d',
+ >>> override=dict(type='Constant', name='reg', val=3, bias=4))
+ >>> initialize(model, init_cfg)
+
+ >>> model = ResNet(depth=50)
+ >>> # Initialize weights with the pretrained model.
+ >>> init_cfg = dict(type='Pretrained',
+ checkpoint='torchvision://resnet50')
+ >>> initialize(model, init_cfg)
+
+ >>> # Initialize weights of a sub-module with the specific part of
+ >>> # a pretrained model by using "prefix".
+ >>> url = 'http://download.openmmlab.com/mmdetection/v2.0/retinanet/'\
+ >>> 'retinanet_r50_fpn_1x_coco/'\
+ >>> 'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth'
+ >>> init_cfg = dict(type='Pretrained',
+ checkpoint=url, prefix='backbone.')
+ """
+ if not isinstance(init_cfg, (dict, list)):
+ raise TypeError(f'init_cfg must be a dict or a list of dict, \
+ but got {type(init_cfg)}')
+
+ if isinstance(init_cfg, dict):
+ init_cfg = [init_cfg]
+
+ for cfg in init_cfg:
+ # should deeply copy the original config because cfg may be used by
+ # other modules, e.g., one init_cfg shared by multiple bottleneck
+ # blocks, the expected cfg will be changed after pop and will change
+ # the initialization behavior of other modules
+ cp_cfg = copy.deepcopy(cfg)
+ override = cp_cfg.pop('override', None)
+ _initialize(module, cp_cfg)
+
+ if override is not None:
+ cp_cfg.pop('layer', None)
+ _initialize_override(module, override, cp_cfg)
+ else:
+ # All attributes in module have same initialization.
+ pass
+
+
+def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
+ b: float) -> Tensor:
+ # Method based on
+ # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ # Modified from
+ # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+ 'The distribution of values may be incorrect.',
+ stacklevel=2)
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ lower = norm_cdf((a - mean) / std)
+ upper = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [lower, upper], then translate
+ # to [2lower-1, 2upper-1].
+ tensor.uniform_(2 * lower - 1, 2 * upper - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor: Tensor,
+ mean: float = 0.,
+ std: float = 1.,
+ a: float = -2.,
+ b: float = 2.) -> Tensor:
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+
+ Modified from
+ https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
+
+ Args:
+ tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
+ mean (float): the mean of the normal distribution.
+ std (float): the standard deviation of the normal distribution.
+ a (float): the minimum cutoff value.
+ b (float): the maximum cutoff value.
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
diff --git a/src/custom_mmpkg/custom_mmcv/cnn/vgg.py b/src/custom_mmpkg/custom_mmcv/cnn/vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..8778b649561a45a9652b1a15a26c2d171e58f3e1
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/cnn/vgg.py
@@ -0,0 +1,175 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+
+import torch.nn as nn
+
+from .utils import constant_init, kaiming_init, normal_init
+
+
+def conv3x3(in_planes, out_planes, dilation=1):
+ """3x3 convolution with padding."""
+ return nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ padding=dilation,
+ dilation=dilation)
+
+
+def make_vgg_layer(inplanes,
+ planes,
+ num_blocks,
+ dilation=1,
+ with_bn=False,
+ ceil_mode=False):
+ layers = []
+ for _ in range(num_blocks):
+ layers.append(conv3x3(inplanes, planes, dilation))
+ if with_bn:
+ layers.append(nn.BatchNorm2d(planes))
+ layers.append(nn.ReLU(inplace=True))
+ inplanes = planes
+ layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode))
+
+ return layers
+
+
+class VGG(nn.Module):
+ """VGG backbone.
+
+ Args:
+ depth (int): Depth of vgg, from {11, 13, 16, 19}.
+ with_bn (bool): Use BatchNorm or not.
+ num_classes (int): number of classes for classification.
+ num_stages (int): VGG stages, normally 5.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+ not freezing any parameters.
+ bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
+ running stats (mean and var).
+ bn_frozen (bool): Whether to freeze weight and bias of BN layers.
+ """
+
+ arch_settings = {
+ 11: (1, 1, 2, 2, 2),
+ 13: (2, 2, 2, 2, 2),
+ 16: (2, 2, 3, 3, 3),
+ 19: (2, 2, 4, 4, 4)
+ }
+
+ def __init__(self,
+ depth,
+ with_bn=False,
+ num_classes=-1,
+ num_stages=5,
+ dilations=(1, 1, 1, 1, 1),
+ out_indices=(0, 1, 2, 3, 4),
+ frozen_stages=-1,
+ bn_eval=True,
+ bn_frozen=False,
+ ceil_mode=False,
+ with_last_pool=True):
+ super(VGG, self).__init__()
+ if depth not in self.arch_settings:
+ raise KeyError(f'invalid depth {depth} for vgg')
+ assert num_stages >= 1 and num_stages <= 5
+ stage_blocks = self.arch_settings[depth]
+ self.stage_blocks = stage_blocks[:num_stages]
+ assert len(dilations) == num_stages
+ assert max(out_indices) <= num_stages
+
+ self.num_classes = num_classes
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.bn_eval = bn_eval
+ self.bn_frozen = bn_frozen
+
+ self.inplanes = 3
+ start_idx = 0
+ vgg_layers = []
+ self.range_sub_modules = []
+ for i, num_blocks in enumerate(self.stage_blocks):
+ num_modules = num_blocks * (2 + with_bn) + 1
+ end_idx = start_idx + num_modules
+ dilation = dilations[i]
+ planes = 64 * 2**i if i < 4 else 512
+ vgg_layer = make_vgg_layer(
+ self.inplanes,
+ planes,
+ num_blocks,
+ dilation=dilation,
+ with_bn=with_bn,
+ ceil_mode=ceil_mode)
+ vgg_layers.extend(vgg_layer)
+ self.inplanes = planes
+ self.range_sub_modules.append([start_idx, end_idx])
+ start_idx = end_idx
+ if not with_last_pool:
+ vgg_layers.pop(-1)
+ self.range_sub_modules[-1][1] -= 1
+ self.module_name = 'features'
+ self.add_module(self.module_name, nn.Sequential(*vgg_layers))
+
+ if self.num_classes > 0:
+ self.classifier = nn.Sequential(
+ nn.Linear(512 * 7 * 7, 4096),
+ nn.ReLU(True),
+ nn.Dropout(),
+ nn.Linear(4096, 4096),
+ nn.ReLU(True),
+ nn.Dropout(),
+ nn.Linear(4096, num_classes),
+ )
+
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+ logger = logging.getLogger()
+ from ..runner import load_checkpoint
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, nn.BatchNorm2d):
+ constant_init(m, 1)
+ elif isinstance(m, nn.Linear):
+ normal_init(m, std=0.01)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x):
+ outs = []
+ vgg_layers = getattr(self, self.module_name)
+ for i in range(len(self.stage_blocks)):
+ for j in range(*self.range_sub_modules[i]):
+ vgg_layer = vgg_layers[j]
+ x = vgg_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ if self.num_classes > 0:
+ x = x.view(x.size(0), -1)
+ x = self.classifier(x)
+ outs.append(x)
+ if len(outs) == 1:
+ return outs[0]
+ else:
+ return tuple(outs)
+
+ def train(self, mode=True):
+ super(VGG, self).train(mode)
+ if self.bn_eval:
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+ if self.bn_frozen:
+ for params in m.parameters():
+ params.requires_grad = False
+ vgg_layers = getattr(self, self.module_name)
+ if mode and self.frozen_stages >= 0:
+ for i in range(self.frozen_stages):
+ for j in range(*self.range_sub_modules[i]):
+ mod = vgg_layers[j]
+ mod.eval()
+ for param in mod.parameters():
+ param.requires_grad = False
diff --git a/src/custom_mmpkg/custom_mmcv/engine/__init__.py b/src/custom_mmpkg/custom_mmcv/engine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3193b7f664e19ce2458d81c836597fa22e4bb082
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/engine/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .test import (collect_results_cpu, collect_results_gpu, multi_gpu_test,
+ single_gpu_test)
+
+__all__ = [
+ 'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test',
+ 'single_gpu_test'
+]
diff --git a/src/custom_mmpkg/custom_mmcv/engine/test.py b/src/custom_mmpkg/custom_mmcv/engine/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac64007f1784b8999b969b9fe4baca393c44d257
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/engine/test.py
@@ -0,0 +1,202 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import pickle
+import shutil
+import tempfile
+import time
+
+import torch
+import torch.distributed as dist
+
+import custom_mmpkg.custom_mmcv as mmcv
+from custom_mmpkg.custom_mmcv.runner import get_dist_info
+
+
+def single_gpu_test(model, data_loader):
+ """Test model with a single gpu.
+
+ This method tests model with a single gpu and displays test progress bar.
+
+ Args:
+ model (nn.Module): Model to be tested.
+ data_loader (nn.Dataloader): Pytorch data loader.
+
+ Returns:
+ list: The prediction results.
+ """
+ model.eval()
+ results = []
+ dataset = data_loader.dataset
+ prog_bar = mmcv.ProgressBar(len(dataset))
+ for data in data_loader:
+ with torch.no_grad():
+ result = model(return_loss=False, **data)
+ results.extend(result)
+
+ # Assume result has the same length of batch_size
+ # refer to https://github.com/open-mmlab/mmcv/issues/985
+ batch_size = len(result)
+ for _ in range(batch_size):
+ prog_bar.update()
+ return results
+
+
+def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
+ """Test model with multiple gpus.
+
+ This method tests model with multiple gpus and collects the results
+ under two different modes: gpu and cpu modes. By setting
+ ``gpu_collect=True``, it encodes results to gpu tensors and use gpu
+ communication for results collection. On cpu mode it saves the results on
+ different gpus to ``tmpdir`` and collects them by the rank 0 worker.
+
+ Args:
+ model (nn.Module): Model to be tested.
+ data_loader (nn.Dataloader): Pytorch data loader.
+ tmpdir (str): Path of directory to save the temporary results from
+ different gpus under cpu mode.
+ gpu_collect (bool): Option to use either gpu or cpu to collect results.
+
+ Returns:
+ list: The prediction results.
+ """
+ model.eval()
+ results = []
+ dataset = data_loader.dataset
+ rank, world_size = get_dist_info()
+ if rank == 0:
+ prog_bar = mmcv.ProgressBar(len(dataset))
+ time.sleep(2) # This line can prevent deadlock problem in some cases.
+ for i, data in enumerate(data_loader):
+ with torch.no_grad():
+ result = model(return_loss=False, **data)
+ results.extend(result)
+
+ if rank == 0:
+ batch_size = len(result)
+ batch_size_all = batch_size * world_size
+ if batch_size_all + prog_bar.completed > len(dataset):
+ batch_size_all = len(dataset) - prog_bar.completed
+ for _ in range(batch_size_all):
+ prog_bar.update()
+
+ # collect results from all ranks
+ if gpu_collect:
+ results = collect_results_gpu(results, len(dataset))
+ else:
+ results = collect_results_cpu(results, len(dataset), tmpdir)
+ return results
+
+
+def collect_results_cpu(result_part, size, tmpdir=None):
+ """Collect results under cpu mode.
+
+ On cpu mode, this function will save the results on different gpus to
+ ``tmpdir`` and collect them by the rank 0 worker.
+
+ Args:
+ result_part (list): Result list containing result parts
+ to be collected.
+ size (int): Size of the results, commonly equal to length of
+ the results.
+ tmpdir (str | None): temporal directory for collected results to
+ store. If set to None, it will create a random temporal directory
+ for it.
+
+ Returns:
+ list: The collected results.
+ """
+ rank, world_size = get_dist_info()
+ # create a tmp dir if it is not specified
+ if tmpdir is None:
+ MAX_LEN = 512
+ # 32 is whitespace
+ dir_tensor = torch.full((MAX_LEN, ),
+ 32,
+ dtype=torch.uint8,
+ device='cuda')
+ if rank == 0:
+ mmcv.mkdir_or_exist('.dist_test')
+ tmpdir = tempfile.mkdtemp(dir='.dist_test')
+ tmpdir = torch.tensor(
+ bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
+ dir_tensor[:len(tmpdir)] = tmpdir
+ dist.broadcast(dir_tensor, 0)
+ tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+ else:
+ mmcv.mkdir_or_exist(tmpdir)
+ # dump the part result to the dir
+ mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
+ dist.barrier()
+ # collect all parts
+ if rank != 0:
+ return None
+ else:
+ # load results of all parts from tmp dir
+ part_list = []
+ for i in range(world_size):
+ part_file = osp.join(tmpdir, f'part_{i}.pkl')
+ part_result = mmcv.load(part_file)
+ # When data is severely insufficient, an empty part_result
+ # on a certain gpu could makes the overall outputs empty.
+ if part_result:
+ part_list.append(part_result)
+ # sort the results
+ ordered_results = []
+ for res in zip(*part_list):
+ ordered_results.extend(list(res))
+ # the dataloader may pad some samples
+ ordered_results = ordered_results[:size]
+ # remove tmp dir
+ shutil.rmtree(tmpdir)
+ return ordered_results
+
+
+def collect_results_gpu(result_part, size):
+ """Collect results under gpu mode.
+
+ On gpu mode, this function will encode results to gpu tensors and use gpu
+ communication for results collection.
+
+ Args:
+ result_part (list): Result list containing result parts
+ to be collected.
+ size (int): Size of the results, commonly equal to length of
+ the results.
+
+ Returns:
+ list: The collected results.
+ """
+ rank, world_size = get_dist_info()
+ # dump result part to tensor with pickle
+ part_tensor = torch.tensor(
+ bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
+ # gather all result part tensor shape
+ shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+ shape_list = [shape_tensor.clone() for _ in range(world_size)]
+ dist.all_gather(shape_list, shape_tensor)
+ # padding result part tensor to max length
+ shape_max = torch.tensor(shape_list).max()
+ part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+ part_send[:shape_tensor[0]] = part_tensor
+ part_recv_list = [
+ part_tensor.new_zeros(shape_max) for _ in range(world_size)
+ ]
+ # gather all result part
+ dist.all_gather(part_recv_list, part_send)
+
+ if rank == 0:
+ part_list = []
+ for recv, shape in zip(part_recv_list, shape_list):
+ part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())
+ # When data is severely insufficient, an empty part_result
+ # on a certain gpu could makes the overall outputs empty.
+ if part_result:
+ part_list.append(part_result)
+ # sort the results
+ ordered_results = []
+ for res in zip(*part_list):
+ ordered_results.extend(list(res))
+ # the dataloader may pad some samples
+ ordered_results = ordered_results[:size]
+ return ordered_results
diff --git a/src/custom_mmpkg/custom_mmcv/fileio/__init__.py b/src/custom_mmpkg/custom_mmcv/fileio/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2051b85f7e59bff7bdbaa131849ce8cd31f059a4
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/fileio/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .file_client import BaseStorageBackend, FileClient
+from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
+from .io import dump, load, register_handler
+from .parse import dict_from_file, list_from_file
+
+__all__ = [
+ 'BaseStorageBackend', 'FileClient', 'load', 'dump', 'register_handler',
+ 'BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler',
+ 'list_from_file', 'dict_from_file'
+]
diff --git a/src/custom_mmpkg/custom_mmcv/fileio/file_client.py b/src/custom_mmpkg/custom_mmcv/fileio/file_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..c060e6e88cce26d13b297d7aeca83e7b2be119bc
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/fileio/file_client.py
@@ -0,0 +1,1148 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+import os
+import os.path as osp
+import re
+import tempfile
+import warnings
+from abc import ABCMeta, abstractmethod
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Iterable, Iterator, Optional, Tuple, Union
+from urllib.request import urlopen
+
+import custom_mmpkg.custom_mmcv as mmcv
+from custom_mmpkg.custom_mmcv.utils.misc import has_method
+from custom_mmpkg.custom_mmcv.utils.path import is_filepath
+
+
+class BaseStorageBackend(metaclass=ABCMeta):
+ """Abstract class of storage backends.
+
+ All backends need to implement two apis: ``get()`` and ``get_text()``.
+ ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
+ as texts.
+ """
+
+ # a flag to indicate whether the backend can create a symlink for a file
+ _allow_symlink = False
+
+ @property
+ def name(self):
+ return self.__class__.__name__
+
+ @property
+ def allow_symlink(self):
+ return self._allow_symlink
+
+ @abstractmethod
+ def get(self, filepath):
+ pass
+
+ @abstractmethod
+ def get_text(self, filepath):
+ pass
+
+
+class CephBackend(BaseStorageBackend):
+ """Ceph storage backend (for internal use).
+
+ Args:
+ path_mapping (dict|None): path mapping dict from local path to Petrel
+ path. When ``path_mapping={'src': 'dst'}``, ``src`` in ``filepath``
+ will be replaced by ``dst``. Default: None.
+
+ .. warning::
+ :class:`mmcv.fileio.file_client.CephBackend` will be deprecated,
+ please use :class:`mmcv.fileio.file_client.PetrelBackend` instead.
+ """
+
+ def __init__(self, path_mapping=None):
+ try:
+ import ceph
+ except ImportError:
+ raise ImportError('Please install ceph to enable CephBackend.')
+
+ warnings.warn(
+ 'CephBackend will be deprecated, please use PetrelBackend instead')
+ self._client = ceph.S3Client()
+ assert isinstance(path_mapping, dict) or path_mapping is None
+ self.path_mapping = path_mapping
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ if self.path_mapping is not None:
+ for k, v in self.path_mapping.items():
+ filepath = filepath.replace(k, v)
+ value = self._client.Get(filepath)
+ value_buf = memoryview(value)
+ return value_buf
+
+ def get_text(self, filepath, encoding=None):
+ raise NotImplementedError
+
+
+class PetrelBackend(BaseStorageBackend):
+ """Petrel storage backend (for internal use).
+
+ PetrelBackend supports reading and writing data to multiple clusters.
+ If the file path contains the cluster name, PetrelBackend will read data
+ from specified cluster or write data to it. Otherwise, PetrelBackend will
+ access the default cluster.
+
+ Args:
+ path_mapping (dict, optional): Path mapping dict from local path to
+ Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in
+ ``filepath`` will be replaced by ``dst``. Default: None.
+ enable_mc (bool, optional): Whether to enable memcached support.
+ Default: True.
+
+ Examples:
+ >>> filepath1 = 's3://path/of/file'
+ >>> filepath2 = 'cluster-name:s3://path/of/file'
+ >>> client = PetrelBackend()
+ >>> client.get(filepath1) # get data from default cluster
+ >>> client.get(filepath2) # get data from 'cluster-name' cluster
+ """
+
+ def __init__(self,
+ path_mapping: Optional[dict] = None,
+ enable_mc: bool = True):
+ try:
+ from petrel_client import client
+ except ImportError:
+ raise ImportError('Please install petrel_client to enable '
+ 'PetrelBackend.')
+
+ self._client = client.Client(enable_mc=enable_mc)
+ assert isinstance(path_mapping, dict) or path_mapping is None
+ self.path_mapping = path_mapping
+
+ def _map_path(self, filepath: Union[str, Path]) -> str:
+ """Map ``filepath`` to a string path whose prefix will be replaced by
+ :attr:`self.path_mapping`.
+
+ Args:
+ filepath (str): Path to be mapped.
+ """
+ filepath = str(filepath)
+ if self.path_mapping is not None:
+ for k, v in self.path_mapping.items():
+ filepath = filepath.replace(k, v)
+ return filepath
+
+ def _format_path(self, filepath: str) -> str:
+ """Convert a ``filepath`` to standard format of petrel oss.
+
+ If the ``filepath`` is concatenated by ``os.path.join``, in a Windows
+ environment, the ``filepath`` will be the format of
+ 's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the
+ above ``filepath`` will be converted to 's3://bucket_name/image.jpg'.
+
+ Args:
+ filepath (str): Path to be formatted.
+ """
+ return re.sub(r'\\+', '/', filepath)
+
+ def get(self, filepath: Union[str, Path]) -> memoryview:
+ """Read data from a given ``filepath`` with 'rb' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+
+ Returns:
+ memoryview: A memory view of expected bytes object to avoid
+ copying. The memoryview object can be converted to bytes by
+ ``value_buf.tobytes()``.
+ """
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ value = self._client.Get(filepath)
+ value_buf = memoryview(value)
+ return value_buf
+
+ def get_text(self,
+ filepath: Union[str, Path],
+ encoding: str = 'utf-8') -> str:
+ """Read data from a given ``filepath`` with 'r' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+ encoding (str): The encoding format used to open the ``filepath``.
+ Default: 'utf-8'.
+
+ Returns:
+ str: Expected text reading from ``filepath``.
+ """
+ return str(self.get(filepath), encoding=encoding)
+
+ def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
+ """Save data to a given ``filepath``.
+
+ Args:
+ obj (bytes): Data to be saved.
+ filepath (str or Path): Path to write data.
+ """
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ self._client.put(filepath, obj)
+
+ def put_text(self,
+ obj: str,
+ filepath: Union[str, Path],
+ encoding: str = 'utf-8') -> None:
+ """Save data to a given ``filepath``.
+
+ Args:
+ obj (str): Data to be written.
+ filepath (str or Path): Path to write data.
+ encoding (str): The encoding format used to encode the ``obj``.
+ Default: 'utf-8'.
+ """
+ self.put(bytes(obj, encoding=encoding), filepath)
+
+ def remove(self, filepath: Union[str, Path]) -> None:
+ """Remove a file.
+
+ Args:
+ filepath (str or Path): Path to be removed.
+ """
+ if not has_method(self._client, 'delete'):
+ raise NotImplementedError(
+ ('Current version of Petrel Python SDK has not supported '
+ 'the `delete` method, please use a higher version or dev'
+ ' branch instead.'))
+
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ self._client.delete(filepath)
+
+ def exists(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path exists.
+
+ Args:
+ filepath (str or Path): Path to be checked whether exists.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
+ """
+ if not (has_method(self._client, 'contains')
+ and has_method(self._client, 'isdir')):
+ raise NotImplementedError(
+ ('Current version of Petrel Python SDK has not supported '
+ 'the `contains` and `isdir` methods, please use a higher'
+ 'version or dev branch instead.'))
+
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ return self._client.contains(filepath) or self._client.isdir(filepath)
+
+ def isdir(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path is a directory.
+
+ Args:
+ filepath (str or Path): Path to be checked whether it is a
+ directory.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` points to a directory,
+ ``False`` otherwise.
+ """
+ if not has_method(self._client, 'isdir'):
+ raise NotImplementedError(
+ ('Current version of Petrel Python SDK has not supported '
+ 'the `isdir` method, please use a higher version or dev'
+ ' branch instead.'))
+
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ return self._client.isdir(filepath)
+
+ def isfile(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path is a file.
+
+ Args:
+ filepath (str or Path): Path to be checked whether it is a file.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` points to a file, ``False``
+ otherwise.
+ """
+ if not has_method(self._client, 'contains'):
+ raise NotImplementedError(
+ ('Current version of Petrel Python SDK has not supported '
+ 'the `contains` method, please use a higher version or '
+ 'dev branch instead.'))
+
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ return self._client.contains(filepath)
+
+ def join_path(self, filepath: Union[str, Path],
+ *filepaths: Union[str, Path]) -> str:
+ """Concatenate all file paths.
+
+ Args:
+ filepath (str or Path): Path to be concatenated.
+
+ Returns:
+ str: The result after concatenation.
+ """
+ filepath = self._format_path(self._map_path(filepath))
+ if filepath.endswith('/'):
+ filepath = filepath[:-1]
+ formatted_paths = [filepath]
+ for path in filepaths:
+ formatted_paths.append(self._format_path(self._map_path(path)))
+ return '/'.join(formatted_paths)
+
+ @contextmanager
+ def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
+ """Download a file from ``filepath`` and return a temporary path.
+
+ ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
+ can be called with ``with`` statement, and when exists from the
+ ``with`` statement, the temporary path will be released.
+
+ Args:
+ filepath (str | Path): Download a file from ``filepath``.
+
+ Examples:
+ >>> client = PetrelBackend()
+ >>> # After existing from the ``with`` clause,
+ >>> # the path will be removed
+ >>> with client.get_local_path('s3://path/of/your/file') as path:
+ ... # do something here
+
+ Yields:
+ Iterable[str]: Only yield one temporary path.
+ """
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ assert self.isfile(filepath)
+ try:
+ f = tempfile.NamedTemporaryFile(delete=False)
+ f.write(self.get(filepath))
+ f.close()
+ yield f.name
+ finally:
+ os.remove(f.name)
+
+ def list_dir_or_file(self,
+ dir_path: Union[str, Path],
+ list_dir: bool = True,
+ list_file: bool = True,
+ suffix: Optional[Union[str, Tuple[str]]] = None,
+ recursive: bool = False) -> Iterator[str]:
+ """Scan a directory to find the interested directories or files in
+ arbitrary order.
+
+ Note:
+ Petrel has no concept of directories but it simulates the directory
+ hierarchy in the filesystem through public prefixes. In addition,
+ if the returned path ends with '/', it means the path is a public
+ prefix which is a logical directory.
+
+ Note:
+ :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
+ In addition, the returned path of directory will not contains the
+ suffix '/' which is consistent with other backends.
+
+ Args:
+ dir_path (str | Path): Path of the directory.
+ list_dir (bool): List the directories. Default: True.
+ list_file (bool): List the path of files. Default: True.
+ suffix (str or tuple[str], optional): File suffix
+ that we are interested in. Default: None.
+ recursive (bool): If set to True, recursively scan the
+ directory. Default: False.
+
+ Yields:
+ Iterable[str]: A relative path to ``dir_path``.
+ """
+ if not has_method(self._client, 'list'):
+ raise NotImplementedError(
+ ('Current version of Petrel Python SDK has not supported '
+ 'the `list` method, please use a higher version or dev'
+ ' branch instead.'))
+
+ dir_path = self._map_path(dir_path)
+ dir_path = self._format_path(dir_path)
+ if list_dir and suffix is not None:
+ raise TypeError(
+ '`list_dir` should be False when `suffix` is not None')
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('`suffix` must be a string or tuple of strings')
+
+ # Petrel's simulated directory hierarchy assumes that directory paths
+ # should end with `/`
+ if not dir_path.endswith('/'):
+ dir_path += '/'
+
+ root = dir_path
+
+ def _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+ recursive):
+ for path in self._client.list(dir_path):
+ # the `self.isdir` is not used here to determine whether path
+ # is a directory, because `self.isdir` relies on
+ # `self._client.list`
+ if path.endswith('/'): # a directory path
+ next_dir_path = self.join_path(dir_path, path)
+ if list_dir:
+ # get the relative path and exclude the last
+ # character '/'
+ rel_dir = next_dir_path[len(root):-1]
+ yield rel_dir
+ if recursive:
+ yield from _list_dir_or_file(next_dir_path, list_dir,
+ list_file, suffix,
+ recursive)
+ else: # a file path
+ absolute_path = self.join_path(dir_path, path)
+ rel_path = absolute_path[len(root):]
+ if (suffix is None
+ or rel_path.endswith(suffix)) and list_file:
+ yield rel_path
+
+ return _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+ recursive)
+
+
+class MemcachedBackend(BaseStorageBackend):
+ """Memcached storage backend.
+
+ Attributes:
+ server_list_cfg (str): Config file for memcached server list.
+ client_cfg (str): Config file for memcached client.
+ sys_path (str | None): Additional path to be appended to `sys.path`.
+ Default: None.
+ """
+
+ def __init__(self, server_list_cfg, client_cfg, sys_path=None):
+ if sys_path is not None:
+ import sys
+ sys.path.append(sys_path)
+ try:
+ import mc
+ except ImportError:
+ raise ImportError(
+ 'Please install memcached to enable MemcachedBackend.')
+
+ self.server_list_cfg = server_list_cfg
+ self.client_cfg = client_cfg
+ self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg,
+ self.client_cfg)
+ # mc.pyvector servers as a point which points to a memory cache
+ self._mc_buffer = mc.pyvector()
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ import mc
+ self._client.Get(filepath, self._mc_buffer)
+ value_buf = mc.ConvertBuffer(self._mc_buffer)
+ return value_buf
+
+ def get_text(self, filepath, encoding=None):
+ raise NotImplementedError
+
+
+class LmdbBackend(BaseStorageBackend):
+ """Lmdb storage backend.
+
+ Args:
+ db_path (str): Lmdb database path.
+ readonly (bool, optional): Lmdb environment parameter. If True,
+ disallow any write operations. Default: True.
+ lock (bool, optional): Lmdb environment parameter. If False, when
+ concurrent access occurs, do not lock the database. Default: False.
+ readahead (bool, optional): Lmdb environment parameter. If False,
+ disable the OS filesystem readahead mechanism, which may improve
+ random read performance when a database is larger than RAM.
+ Default: False.
+
+ Attributes:
+ db_path (str): Lmdb database path.
+ """
+
+ def __init__(self,
+ db_path,
+ readonly=True,
+ lock=False,
+ readahead=False,
+ **kwargs):
+ try:
+ import lmdb
+ except ImportError:
+ raise ImportError('Please install lmdb to enable LmdbBackend.')
+
+ self.db_path = str(db_path)
+ self._client = lmdb.open(
+ self.db_path,
+ readonly=readonly,
+ lock=lock,
+ readahead=readahead,
+ **kwargs)
+
+ def get(self, filepath):
+ """Get values according to the filepath.
+
+ Args:
+ filepath (str | obj:`Path`): Here, filepath is the lmdb key.
+ """
+ filepath = str(filepath)
+ with self._client.begin(write=False) as txn:
+ value_buf = txn.get(filepath.encode('ascii'))
+ return value_buf
+
+ def get_text(self, filepath, encoding=None):
+ raise NotImplementedError
+
+
+class HardDiskBackend(BaseStorageBackend):
+ """Raw hard disks storage backend."""
+
+ _allow_symlink = True
+
+ def get(self, filepath: Union[str, Path]) -> bytes:
+ """Read data from a given ``filepath`` with 'rb' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+
+ Returns:
+ bytes: Expected bytes object.
+ """
+ with open(filepath, 'rb') as f:
+ value_buf = f.read()
+ return value_buf
+
+ def get_text(self,
+ filepath: Union[str, Path],
+ encoding: str = 'utf-8') -> str:
+ """Read data from a given ``filepath`` with 'r' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+ encoding (str): The encoding format used to open the ``filepath``.
+ Default: 'utf-8'.
+
+ Returns:
+ str: Expected text reading from ``filepath``.
+ """
+ with open(filepath, 'r', encoding=encoding) as f:
+ value_buf = f.read()
+ return value_buf
+
+ def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
+ """Write data to a given ``filepath`` with 'wb' mode.
+
+ Note:
+ ``put`` will create a directory if the directory of ``filepath``
+ does not exist.
+
+ Args:
+ obj (bytes): Data to be written.
+ filepath (str or Path): Path to write data.
+ """
+ mmcv.mkdir_or_exist(osp.dirname(filepath))
+ with open(filepath, 'wb') as f:
+ f.write(obj)
+
+ def put_text(self,
+ obj: str,
+ filepath: Union[str, Path],
+ encoding: str = 'utf-8') -> None:
+ """Write data to a given ``filepath`` with 'w' mode.
+
+ Note:
+ ``put_text`` will create a directory if the directory of
+ ``filepath`` does not exist.
+
+ Args:
+ obj (str): Data to be written.
+ filepath (str or Path): Path to write data.
+ encoding (str): The encoding format used to open the ``filepath``.
+ Default: 'utf-8'.
+ """
+ mmcv.mkdir_or_exist(osp.dirname(filepath))
+ with open(filepath, 'w', encoding=encoding) as f:
+ f.write(obj)
+
+ def remove(self, filepath: Union[str, Path]) -> None:
+ """Remove a file.
+
+ Args:
+ filepath (str or Path): Path to be removed.
+ """
+ os.remove(filepath)
+
+ def exists(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path exists.
+
+ Args:
+ filepath (str or Path): Path to be checked whether exists.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
+ """
+ return osp.exists(filepath)
+
+ def isdir(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path is a directory.
+
+ Args:
+ filepath (str or Path): Path to be checked whether it is a
+ directory.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` points to a directory,
+ ``False`` otherwise.
+ """
+ return osp.isdir(filepath)
+
+ def isfile(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path is a file.
+
+ Args:
+ filepath (str or Path): Path to be checked whether it is a file.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` points to a file, ``False``
+ otherwise.
+ """
+ return osp.isfile(filepath)
+
+ def join_path(self, filepath: Union[str, Path],
+ *filepaths: Union[str, Path]) -> str:
+ """Concatenate all file paths.
+
+ Join one or more filepath components intelligently. The return value
+ is the concatenation of filepath and any members of *filepaths.
+
+ Args:
+ filepath (str or Path): Path to be concatenated.
+
+ Returns:
+ str: The result of concatenation.
+ """
+ return osp.join(filepath, *filepaths)
+
+ @contextmanager
+ def get_local_path(
+ self, filepath: Union[str, Path]) -> Iterable[Union[str, Path]]:
+ """Only for unified API and do nothing."""
+ yield filepath
+
+ def list_dir_or_file(self,
+ dir_path: Union[str, Path],
+ list_dir: bool = True,
+ list_file: bool = True,
+ suffix: Optional[Union[str, Tuple[str]]] = None,
+ recursive: bool = False) -> Iterator[str]:
+ """Scan a directory to find the interested directories or files in
+ arbitrary order.
+
+ Note:
+ :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
+
+ Args:
+ dir_path (str | Path): Path of the directory.
+ list_dir (bool): List the directories. Default: True.
+ list_file (bool): List the path of files. Default: True.
+ suffix (str or tuple[str], optional): File suffix
+ that we are interested in. Default: None.
+ recursive (bool): If set to True, recursively scan the
+ directory. Default: False.
+
+ Yields:
+ Iterable[str]: A relative path to ``dir_path``.
+ """
+ if list_dir and suffix is not None:
+ raise TypeError('`suffix` should be None when `list_dir` is True')
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('`suffix` must be a string or tuple of strings')
+
+ root = dir_path
+
+ def _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+ recursive):
+ for entry in os.scandir(dir_path):
+ if not entry.name.startswith('.') and entry.is_file():
+ rel_path = osp.relpath(entry.path, root)
+ if (suffix is None
+ or rel_path.endswith(suffix)) and list_file:
+ yield rel_path
+ elif osp.isdir(entry.path):
+ if list_dir:
+ rel_dir = osp.relpath(entry.path, root)
+ yield rel_dir
+ if recursive:
+ yield from _list_dir_or_file(entry.path, list_dir,
+ list_file, suffix,
+ recursive)
+
+ return _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+ recursive)
+
+
+class HTTPBackend(BaseStorageBackend):
+ """HTTP and HTTPS storage bachend."""
+
+ def get(self, filepath):
+ value_buf = urlopen(filepath).read()
+ return value_buf
+
+ def get_text(self, filepath, encoding='utf-8'):
+ value_buf = urlopen(filepath).read()
+ return value_buf.decode(encoding)
+
+ @contextmanager
+ def get_local_path(self, filepath: str) -> Iterable[str]:
+ """Download a file from ``filepath``.
+
+ ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
+ can be called with ``with`` statement, and when exists from the
+ ``with`` statement, the temporary path will be released.
+
+ Args:
+ filepath (str): Download a file from ``filepath``.
+
+ Examples:
+ >>> client = HTTPBackend()
+ >>> # After existing from the ``with`` clause,
+ >>> # the path will be removed
+ >>> with client.get_local_path('http://path/of/your/file') as path:
+ ... # do something here
+ """
+ try:
+ f = tempfile.NamedTemporaryFile(delete=False)
+ f.write(self.get(filepath))
+ f.close()
+ yield f.name
+ finally:
+ os.remove(f.name)
+
+
+class FileClient:
+ """A general file client to access files in different backends.
+
+ The client loads a file or text in a specified backend from its path
+ and returns it as a binary or text file. There are two ways to choose a
+ backend, the name of backend and the prefix of path. Although both of them
+ can be used to choose a storage backend, ``backend`` has a higher priority
+ that is if they are all set, the storage backend will be chosen by the
+ backend argument. If they are all `None`, the disk backend will be chosen.
+ Note that It can also register other backend accessor with a given name,
+ prefixes, and backend class. In addition, We use the singleton pattern to
+ avoid repeated object creation. If the arguments are the same, the same
+ object will be returned.
+
+ Args:
+ backend (str, optional): The storage backend type. Options are "disk",
+ "ceph", "memcached", "lmdb", "http" and "petrel". Default: None.
+ prefix (str, optional): The prefix of the registered storage backend.
+ Options are "s3", "http", "https". Default: None.
+
+ Examples:
+ >>> # only set backend
+ >>> file_client = FileClient(backend='petrel')
+ >>> # only set prefix
+ >>> file_client = FileClient(prefix='s3')
+ >>> # set both backend and prefix but use backend to choose client
+ >>> file_client = FileClient(backend='petrel', prefix='s3')
+ >>> # if the arguments are the same, the same object is returned
+ >>> file_client1 = FileClient(backend='petrel')
+ >>> file_client1 is file_client
+ True
+
+ Attributes:
+ client (:obj:`BaseStorageBackend`): The backend object.
+ """
+
+ _backends = {
+ 'disk': HardDiskBackend,
+ 'ceph': CephBackend,
+ 'memcached': MemcachedBackend,
+ 'lmdb': LmdbBackend,
+ 'petrel': PetrelBackend,
+ 'http': HTTPBackend,
+ }
+ # This collection is used to record the overridden backends, and when a
+ # backend appears in the collection, the singleton pattern is disabled for
+ # that backend, because if the singleton pattern is used, then the object
+ # returned will be the backend before overwriting
+ _overridden_backends = set()
+ _prefix_to_backends = {
+ 's3': PetrelBackend,
+ 'http': HTTPBackend,
+ 'https': HTTPBackend,
+ }
+ _overridden_prefixes = set()
+
+ _instances = {}
+
+ def __new__(cls, backend=None, prefix=None, **kwargs):
+ if backend is None and prefix is None:
+ backend = 'disk'
+ if backend is not None and backend not in cls._backends:
+ raise ValueError(
+ f'Backend {backend} is not supported. Currently supported ones'
+ f' are {list(cls._backends.keys())}')
+ if prefix is not None and prefix not in cls._prefix_to_backends:
+ raise ValueError(
+ f'prefix {prefix} is not supported. Currently supported ones '
+ f'are {list(cls._prefix_to_backends.keys())}')
+
+ # concatenate the arguments to a unique key for determining whether
+ # objects with the same arguments were created
+ arg_key = f'{backend}:{prefix}'
+ for key, value in kwargs.items():
+ arg_key += f':{key}:{value}'
+
+ # if a backend was overridden, it will create a new object
+ if (arg_key in cls._instances
+ and backend not in cls._overridden_backends
+ and prefix not in cls._overridden_prefixes):
+ _instance = cls._instances[arg_key]
+ else:
+ # create a new object and put it to _instance
+ _instance = super().__new__(cls)
+ if backend is not None:
+ _instance.client = cls._backends[backend](**kwargs)
+ else:
+ _instance.client = cls._prefix_to_backends[prefix](**kwargs)
+
+ cls._instances[arg_key] = _instance
+
+ return _instance
+
+ @property
+ def name(self):
+ return self.client.name
+
+ @property
+ def allow_symlink(self):
+ return self.client.allow_symlink
+
+ @staticmethod
+ def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]:
+ """Parse the prefix of a uri.
+
+ Args:
+ uri (str | Path): Uri to be parsed that contains the file prefix.
+
+ Examples:
+ >>> FileClient.parse_uri_prefix('s3://path/of/your/file')
+ 's3'
+
+ Returns:
+ str | None: Return the prefix of uri if the uri contains '://'
+ else ``None``.
+ """
+ assert is_filepath(uri)
+ uri = str(uri)
+ if '://' not in uri:
+ return None
+ else:
+ prefix, _ = uri.split('://')
+ # In the case of PetrelBackend, the prefix may contains the cluster
+ # name like clusterName:s3
+ if ':' in prefix:
+ _, prefix = prefix.split(':')
+ return prefix
+
+ @classmethod
+ def infer_client(cls,
+ file_client_args: Optional[dict] = None,
+ uri: Optional[Union[str, Path]] = None) -> 'FileClient':
+ """Infer a suitable file client based on the URI and arguments.
+
+ Args:
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. Default: None.
+ uri (str | Path, optional): Uri to be parsed that contains the file
+ prefix. Default: None.
+
+ Examples:
+ >>> uri = 's3://path/of/your/file'
+ >>> file_client = FileClient.infer_client(uri=uri)
+ >>> file_client_args = {'backend': 'petrel'}
+ >>> file_client = FileClient.infer_client(file_client_args)
+
+ Returns:
+ FileClient: Instantiated FileClient object.
+ """
+ assert file_client_args is not None or uri is not None
+ if file_client_args is None:
+ file_prefix = cls.parse_uri_prefix(uri) # type: ignore
+ return cls(prefix=file_prefix)
+ else:
+ return cls(**file_client_args)
+
+ @classmethod
+ def _register_backend(cls, name, backend, force=False, prefixes=None):
+ if not isinstance(name, str):
+ raise TypeError('the backend name should be a string, '
+ f'but got {type(name)}')
+ if not inspect.isclass(backend):
+ raise TypeError(
+ f'backend should be a class but got {type(backend)}')
+ if not issubclass(backend, BaseStorageBackend):
+ raise TypeError(
+ f'backend {backend} is not a subclass of BaseStorageBackend')
+ if not force and name in cls._backends:
+ raise KeyError(
+ f'{name} is already registered as a storage backend, '
+ 'add "force=True" if you want to override it')
+
+ if name in cls._backends and force:
+ cls._overridden_backends.add(name)
+ cls._backends[name] = backend
+
+ if prefixes is not None:
+ if isinstance(prefixes, str):
+ prefixes = [prefixes]
+ else:
+ assert isinstance(prefixes, (list, tuple))
+ for prefix in prefixes:
+ if prefix not in cls._prefix_to_backends:
+ cls._prefix_to_backends[prefix] = backend
+ elif (prefix in cls._prefix_to_backends) and force:
+ cls._overridden_prefixes.add(prefix)
+ cls._prefix_to_backends[prefix] = backend
+ else:
+ raise KeyError(
+ f'{prefix} is already registered as a storage backend,'
+ ' add "force=True" if you want to override it')
+
+ @classmethod
+ def register_backend(cls, name, backend=None, force=False, prefixes=None):
+ """Register a backend to FileClient.
+
+ This method can be used as a normal class method or a decorator.
+
+ .. code-block:: python
+
+ class NewBackend(BaseStorageBackend):
+
+ def get(self, filepath):
+ return filepath
+
+ def get_text(self, filepath):
+ return filepath
+
+ FileClient.register_backend('new', NewBackend)
+
+ or
+
+ .. code-block:: python
+
+ @FileClient.register_backend('new')
+ class NewBackend(BaseStorageBackend):
+
+ def get(self, filepath):
+ return filepath
+
+ def get_text(self, filepath):
+ return filepath
+
+ Args:
+ name (str): The name of the registered backend.
+ backend (class, optional): The backend class to be registered,
+ which must be a subclass of :class:`BaseStorageBackend`.
+ When this method is used as a decorator, backend is None.
+ Defaults to None.
+ force (bool, optional): Whether to override the backend if the name
+ has already been registered. Defaults to False.
+ prefixes (str or list[str] or tuple[str], optional): The prefixes
+ of the registered storage backend. Default: None.
+ `New in version 1.3.15.`
+ """
+ if backend is not None:
+ cls._register_backend(
+ name, backend, force=force, prefixes=prefixes)
+ return
+
+ def _register(backend_cls):
+ cls._register_backend(
+ name, backend_cls, force=force, prefixes=prefixes)
+ return backend_cls
+
+ return _register
+
+ def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]:
+ """Read data from a given ``filepath`` with 'rb' mode.
+
+ Note:
+ There are two types of return values for ``get``, one is ``bytes``
+ and the other is ``memoryview``. The advantage of using memoryview
+ is that you can avoid copying, and if you want to convert it to
+ ``bytes``, you can use ``.tobytes()``.
+
+ Args:
+ filepath (str or Path): Path to read data.
+
+ Returns:
+ bytes | memoryview: Expected bytes object or a memory view of the
+ bytes object.
+ """
+ return self.client.get(filepath)
+
+ def get_text(self, filepath: Union[str, Path], encoding='utf-8') -> str:
+ """Read data from a given ``filepath`` with 'r' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+ encoding (str): The encoding format used to open the ``filepath``.
+ Default: 'utf-8'.
+
+ Returns:
+ str: Expected text reading from ``filepath``.
+ """
+ return self.client.get_text(filepath, encoding)
+
+ def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
+ """Write data to a given ``filepath`` with 'wb' mode.
+
+ Note:
+ ``put`` should create a directory if the directory of ``filepath``
+ does not exist.
+
+ Args:
+ obj (bytes): Data to be written.
+ filepath (str or Path): Path to write data.
+ """
+ self.client.put(obj, filepath)
+
+ def put_text(self, obj: str, filepath: Union[str, Path]) -> None:
+ """Write data to a given ``filepath`` with 'w' mode.
+
+ Note:
+ ``put_text`` should create a directory if the directory of
+ ``filepath`` does not exist.
+
+ Args:
+ obj (str): Data to be written.
+ filepath (str or Path): Path to write data.
+ encoding (str, optional): The encoding format used to open the
+ `filepath`. Default: 'utf-8'.
+ """
+ self.client.put_text(obj, filepath)
+
+ def remove(self, filepath: Union[str, Path]) -> None:
+ """Remove a file.
+
+ Args:
+ filepath (str, Path): Path to be removed.
+ """
+ self.client.remove(filepath)
+
+ def exists(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path exists.
+
+ Args:
+ filepath (str or Path): Path to be checked whether exists.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
+ """
+ return self.client.exists(filepath)
+
+ def isdir(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path is a directory.
+
+ Args:
+ filepath (str or Path): Path to be checked whether it is a
+ directory.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` points to a directory,
+ ``False`` otherwise.
+ """
+ return self.client.isdir(filepath)
+
+ def isfile(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path is a file.
+
+ Args:
+ filepath (str or Path): Path to be checked whether it is a file.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` points to a file, ``False``
+ otherwise.
+ """
+ return self.client.isfile(filepath)
+
+ def join_path(self, filepath: Union[str, Path],
+ *filepaths: Union[str, Path]) -> str:
+ """Concatenate all file paths.
+
+ Join one or more filepath components intelligently. The return value
+ is the concatenation of filepath and any members of *filepaths.
+
+ Args:
+ filepath (str or Path): Path to be concatenated.
+
+ Returns:
+ str: The result of concatenation.
+ """
+ return self.client.join_path(filepath, *filepaths)
+
+ @contextmanager
+ def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
+ """Download data from ``filepath`` and write the data to local path.
+
+ ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
+ can be called with ``with`` statement, and when exists from the
+ ``with`` statement, the temporary path will be released.
+
+ Note:
+ If the ``filepath`` is a local path, just return itself.
+
+ .. warning::
+ ``get_local_path`` is an experimental interface that may change in
+ the future.
+
+ Args:
+ filepath (str or Path): Path to be read data.
+
+ Examples:
+ >>> file_client = FileClient(prefix='s3')
+ >>> with file_client.get_local_path('s3://bucket/abc.jpg') as path:
+ ... # do something here
+
+ Yields:
+ Iterable[str]: Only yield one path.
+ """
+ with self.client.get_local_path(str(filepath)) as local_path:
+ yield local_path
+
+ def list_dir_or_file(self,
+ dir_path: Union[str, Path],
+ list_dir: bool = True,
+ list_file: bool = True,
+ suffix: Optional[Union[str, Tuple[str]]] = None,
+ recursive: bool = False) -> Iterator[str]:
+ """Scan a directory to find the interested directories or files in
+ arbitrary order.
+
+ Note:
+ :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
+
+ Args:
+ dir_path (str | Path): Path of the directory.
+ list_dir (bool): List the directories. Default: True.
+ list_file (bool): List the path of files. Default: True.
+ suffix (str or tuple[str], optional): File suffix
+ that we are interested in. Default: None.
+ recursive (bool): If set to True, recursively scan the
+ directory. Default: False.
+
+ Yields:
+ Iterable[str]: A relative path to ``dir_path``.
+ """
+ yield from self.client.list_dir_or_file(dir_path, list_dir, list_file,
+ suffix, recursive)
diff --git a/src/custom_mmpkg/custom_mmcv/fileio/handlers/__init__.py b/src/custom_mmpkg/custom_mmcv/fileio/handlers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa24d91972837b8756b225f4879bac20436eb72a
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/fileio/handlers/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import BaseFileHandler
+from .json_handler import JsonHandler
+from .pickle_handler import PickleHandler
+from .yaml_handler import YamlHandler
+
+__all__ = ['BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler']
diff --git a/src/custom_mmpkg/custom_mmcv/fileio/handlers/base.py b/src/custom_mmpkg/custom_mmcv/fileio/handlers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..288878bc57282fbb2f12b32290152ca8e9d3cab0
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/fileio/handlers/base.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+
+
+class BaseFileHandler(metaclass=ABCMeta):
+ # `str_like` is a flag to indicate whether the type of file object is
+ # str-like object or bytes-like object. Pickle only processes bytes-like
+ # objects but json only processes str-like object. If it is str-like
+ # object, `StringIO` will be used to process the buffer.
+ str_like = True
+
+ @abstractmethod
+ def load_from_fileobj(self, file, **kwargs):
+ pass
+
+ @abstractmethod
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ pass
+
+ @abstractmethod
+ def dump_to_str(self, obj, **kwargs):
+ pass
+
+ def load_from_path(self, filepath, mode='r', **kwargs):
+ with open(filepath, mode) as f:
+ return self.load_from_fileobj(f, **kwargs)
+
+ def dump_to_path(self, obj, filepath, mode='w', **kwargs):
+ with open(filepath, mode) as f:
+ self.dump_to_fileobj(obj, f, **kwargs)
diff --git a/src/custom_mmpkg/custom_mmcv/fileio/handlers/json_handler.py b/src/custom_mmpkg/custom_mmcv/fileio/handlers/json_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..18d4f15f74139d20adff18b20be5529c592a66b6
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/fileio/handlers/json_handler.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+
+import numpy as np
+
+from .base import BaseFileHandler
+
+
+def set_default(obj):
+ """Set default json values for non-serializable values.
+
+ It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list.
+ It also converts ``np.generic`` (including ``np.int32``, ``np.float32``,
+ etc.) into plain numbers of plain python built-in types.
+ """
+ if isinstance(obj, (set, range)):
+ return list(obj)
+ elif isinstance(obj, np.ndarray):
+ return obj.tolist()
+ elif isinstance(obj, np.generic):
+ return obj.item()
+ raise TypeError(f'{type(obj)} is unsupported for json dump')
+
+
+class JsonHandler(BaseFileHandler):
+
+ def load_from_fileobj(self, file):
+ return json.load(file)
+
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ kwargs.setdefault('default', set_default)
+ json.dump(obj, file, **kwargs)
+
+ def dump_to_str(self, obj, **kwargs):
+ kwargs.setdefault('default', set_default)
+ return json.dumps(obj, **kwargs)
diff --git a/src/custom_mmpkg/custom_mmcv/fileio/handlers/pickle_handler.py b/src/custom_mmpkg/custom_mmcv/fileio/handlers/pickle_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b37c79bed4ef9fd8913715e62dbe3fc5cafdc3aa
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/fileio/handlers/pickle_handler.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pickle
+
+from .base import BaseFileHandler
+
+
+class PickleHandler(BaseFileHandler):
+
+ str_like = False
+
+ def load_from_fileobj(self, file, **kwargs):
+ return pickle.load(file, **kwargs)
+
+ def load_from_path(self, filepath, **kwargs):
+ return super(PickleHandler, self).load_from_path(
+ filepath, mode='rb', **kwargs)
+
+ def dump_to_str(self, obj, **kwargs):
+ kwargs.setdefault('protocol', 2)
+ return pickle.dumps(obj, **kwargs)
+
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ kwargs.setdefault('protocol', 2)
+ pickle.dump(obj, file, **kwargs)
+
+ def dump_to_path(self, obj, filepath, **kwargs):
+ super(PickleHandler, self).dump_to_path(
+ obj, filepath, mode='wb', **kwargs)
diff --git a/src/custom_mmpkg/custom_mmcv/fileio/handlers/yaml_handler.py b/src/custom_mmpkg/custom_mmcv/fileio/handlers/yaml_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5aa2eea1e8c76f8baf753d1c8c959dee665e543
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/fileio/handlers/yaml_handler.py
@@ -0,0 +1,24 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import yaml
+
+try:
+ from yaml import CLoader as Loader, CDumper as Dumper
+except ImportError:
+ from yaml import Loader, Dumper
+
+from .base import BaseFileHandler # isort:skip
+
+
+class YamlHandler(BaseFileHandler):
+
+ def load_from_fileobj(self, file, **kwargs):
+ kwargs.setdefault('Loader', Loader)
+ return yaml.load(file, **kwargs)
+
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ kwargs.setdefault('Dumper', Dumper)
+ yaml.dump(obj, file, **kwargs)
+
+ def dump_to_str(self, obj, **kwargs):
+ kwargs.setdefault('Dumper', Dumper)
+ return yaml.dump(obj, **kwargs)
diff --git a/src/custom_mmpkg/custom_mmcv/fileio/io.py b/src/custom_mmpkg/custom_mmcv/fileio/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaefde58aa3ea5b58f86249ce7e1c40c186eb8dd
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/fileio/io.py
@@ -0,0 +1,151 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from io import BytesIO, StringIO
+from pathlib import Path
+
+from ..utils import is_list_of, is_str
+from .file_client import FileClient
+from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
+
+file_handlers = {
+ 'json': JsonHandler(),
+ 'yaml': YamlHandler(),
+ 'yml': YamlHandler(),
+ 'pickle': PickleHandler(),
+ 'pkl': PickleHandler()
+}
+
+
+def load(file, file_format=None, file_client_args=None, **kwargs):
+ """Load data from json/yaml/pickle files.
+
+ This method provides a unified api for loading data from serialized files.
+
+ Note:
+ In v1.3.16 and later, ``load`` supports loading data from serialized
+ files those can be storaged in different backends.
+
+ Args:
+ file (str or :obj:`Path` or file-like object): Filename or a file-like
+ object.
+ file_format (str, optional): If not specified, the file format will be
+ inferred from the file extension, otherwise use the specified one.
+ Currently supported formats include "json", "yaml/yml" and
+ "pickle/pkl".
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
+ Default: None.
+
+ Examples:
+ >>> load('/path/of/your/file') # file is storaged in disk
+ >>> load('https://path/of/your/file') # file is storaged in Internet
+ >>> load('s3://path/of/your/file') # file is storaged in petrel
+
+ Returns:
+ The content from the file.
+ """
+ if isinstance(file, Path):
+ file = str(file)
+ if file_format is None and is_str(file):
+ file_format = file.split('.')[-1]
+ if file_format not in file_handlers:
+ raise TypeError(f'Unsupported format: {file_format}')
+
+ handler = file_handlers[file_format]
+ if is_str(file):
+ file_client = FileClient.infer_client(file_client_args, file)
+ if handler.str_like:
+ with StringIO(file_client.get_text(file)) as f:
+ obj = handler.load_from_fileobj(f, **kwargs)
+ else:
+ with BytesIO(file_client.get(file)) as f:
+ obj = handler.load_from_fileobj(f, **kwargs)
+ elif hasattr(file, 'read'):
+ obj = handler.load_from_fileobj(file, **kwargs)
+ else:
+ raise TypeError('"file" must be a filepath str or a file-object')
+ return obj
+
+
+def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
+ """Dump data to json/yaml/pickle strings or files.
+
+ This method provides a unified api for dumping data as strings or to files,
+ and also supports custom arguments for each file format.
+
+ Note:
+ In v1.3.16 and later, ``dump`` supports dumping data as strings or to
+ files which is saved to different backends.
+
+ Args:
+ obj (any): The python object to be dumped.
+ file (str or :obj:`Path` or file-like object, optional): If not
+ specified, then the object is dumped to a str, otherwise to a file
+ specified by the filename or file-like object.
+ file_format (str, optional): Same as :func:`load`.
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
+ Default: None.
+
+ Examples:
+ >>> dump('hello world', '/path/of/your/file') # disk
+ >>> dump('hello world', 's3://path/of/your/file') # ceph or petrel
+
+ Returns:
+ bool: True for success, False otherwise.
+ """
+ if isinstance(file, Path):
+ file = str(file)
+ if file_format is None:
+ if is_str(file):
+ file_format = file.split('.')[-1]
+ elif file is None:
+ raise ValueError(
+ 'file_format must be specified since file is None')
+ if file_format not in file_handlers:
+ raise TypeError(f'Unsupported format: {file_format}')
+
+ handler = file_handlers[file_format]
+ if file is None:
+ return handler.dump_to_str(obj, **kwargs)
+ elif is_str(file):
+ file_client = FileClient.infer_client(file_client_args, file)
+ if handler.str_like:
+ with StringIO() as f:
+ handler.dump_to_fileobj(obj, f, **kwargs)
+ file_client.put_text(f.getvalue(), file)
+ else:
+ with BytesIO() as f:
+ handler.dump_to_fileobj(obj, f, **kwargs)
+ file_client.put(f.getvalue(), file)
+ elif hasattr(file, 'write'):
+ handler.dump_to_fileobj(obj, file, **kwargs)
+ else:
+ raise TypeError('"file" must be a filename str or a file-object')
+
+
+def _register_handler(handler, file_formats):
+ """Register a handler for some file extensions.
+
+ Args:
+ handler (:obj:`BaseFileHandler`): Handler to be registered.
+ file_formats (str or list[str]): File formats to be handled by this
+ handler.
+ """
+ if not isinstance(handler, BaseFileHandler):
+ raise TypeError(
+ f'handler must be a child of BaseFileHandler, not {type(handler)}')
+ if isinstance(file_formats, str):
+ file_formats = [file_formats]
+ if not is_list_of(file_formats, str):
+ raise TypeError('file_formats must be a str or a list of str')
+ for ext in file_formats:
+ file_handlers[ext] = handler
+
+
+def register_handler(file_formats, **kwargs):
+
+ def wrap(cls):
+ _register_handler(cls(**kwargs), file_formats)
+ return cls
+
+ return wrap
diff --git a/src/custom_mmpkg/custom_mmcv/fileio/parse.py b/src/custom_mmpkg/custom_mmcv/fileio/parse.py
new file mode 100644
index 0000000000000000000000000000000000000000..f60f0d611b8d75692221d0edd7dc993b0a6445c9
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/fileio/parse.py
@@ -0,0 +1,97 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from io import StringIO
+
+from .file_client import FileClient
+
+
+def list_from_file(filename,
+ prefix='',
+ offset=0,
+ max_num=0,
+ encoding='utf-8',
+ file_client_args=None):
+ """Load a text file and parse the content as a list of strings.
+
+ Note:
+ In v1.3.16 and later, ``list_from_file`` supports loading a text file
+ which can be storaged in different backends and parsing the content as
+ a list for strings.
+
+ Args:
+ filename (str): Filename.
+ prefix (str): The prefix to be inserted to the beginning of each item.
+ offset (int): The offset of lines.
+ max_num (int): The maximum number of lines to be read,
+ zeros and negatives mean no limitation.
+ encoding (str): Encoding used to open the file. Default utf-8.
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
+ Default: None.
+
+ Examples:
+ >>> list_from_file('/path/of/your/file') # disk
+ ['hello', 'world']
+ >>> list_from_file('s3://path/of/your/file') # ceph or petrel
+ ['hello', 'world']
+
+ Returns:
+ list[str]: A list of strings.
+ """
+ cnt = 0
+ item_list = []
+ file_client = FileClient.infer_client(file_client_args, filename)
+ with StringIO(file_client.get_text(filename, encoding)) as f:
+ for _ in range(offset):
+ f.readline()
+ for line in f:
+ if 0 < max_num <= cnt:
+ break
+ item_list.append(prefix + line.rstrip('\n\r'))
+ cnt += 1
+ return item_list
+
+
+def dict_from_file(filename,
+ key_type=str,
+ encoding='utf-8',
+ file_client_args=None):
+ """Load a text file and parse the content as a dict.
+
+ Each line of the text file will be two or more columns split by
+ whitespaces or tabs. The first column will be parsed as dict keys, and
+ the following columns will be parsed as dict values.
+
+ Note:
+ In v1.3.16 and later, ``dict_from_file`` supports loading a text file
+ which can be storaged in different backends and parsing the content as
+ a dict.
+
+ Args:
+ filename(str): Filename.
+ key_type(type): Type of the dict keys. str is user by default and
+ type conversion will be performed if specified.
+ encoding (str): Encoding used to open the file. Default utf-8.
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
+ Default: None.
+
+ Examples:
+ >>> dict_from_file('/path/of/your/file') # disk
+ {'key1': 'value1', 'key2': 'value2'}
+ >>> dict_from_file('s3://path/of/your/file') # ceph or petrel
+ {'key1': 'value1', 'key2': 'value2'}
+
+ Returns:
+ dict: The parsed contents.
+ """
+ mapping = {}
+ file_client = FileClient.infer_client(file_client_args, filename)
+ with StringIO(file_client.get_text(filename, encoding)) as f:
+ for line in f:
+ items = line.rstrip('\n').split()
+ assert len(items) >= 2
+ key = key_type(items[0])
+ val = items[1:] if len(items) > 2 else items[1]
+ mapping[key] = val
+ return mapping
diff --git a/src/custom_mmpkg/custom_mmcv/image/__init__.py b/src/custom_mmpkg/custom_mmcv/image/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0051d609d3de4e7562e3fe638335c66617c4d91
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/image/__init__.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .colorspace import (bgr2gray, bgr2hls, bgr2hsv, bgr2rgb, bgr2ycbcr,
+ gray2bgr, gray2rgb, hls2bgr, hsv2bgr, imconvert,
+ rgb2bgr, rgb2gray, rgb2ycbcr, ycbcr2bgr, ycbcr2rgb)
+from .geometric import (cutout, imcrop, imflip, imflip_, impad,
+ impad_to_multiple, imrescale, imresize, imresize_like,
+ imresize_to_multiple, imrotate, imshear, imtranslate,
+ rescale_size)
+from .io import imfrombytes, imread, imwrite, supported_backends, use_backend
+from .misc import tensor2imgs
+from .photometric import (adjust_brightness, adjust_color, adjust_contrast,
+ adjust_lighting, adjust_sharpness, auto_contrast,
+ clahe, imdenormalize, imequalize, iminvert,
+ imnormalize, imnormalize_, lut_transform, posterize,
+ solarize)
+
+__all__ = [
+ 'bgr2gray', 'bgr2hls', 'bgr2hsv', 'bgr2rgb', 'gray2bgr', 'gray2rgb',
+ 'hls2bgr', 'hsv2bgr', 'imconvert', 'rgb2bgr', 'rgb2gray', 'imrescale',
+ 'imresize', 'imresize_like', 'imresize_to_multiple', 'rescale_size',
+ 'imcrop', 'imflip', 'imflip_', 'impad', 'impad_to_multiple', 'imrotate',
+ 'imfrombytes', 'imread', 'imwrite', 'supported_backends', 'use_backend',
+ 'imdenormalize', 'imnormalize', 'imnormalize_', 'iminvert', 'posterize',
+ 'solarize', 'rgb2ycbcr', 'bgr2ycbcr', 'ycbcr2rgb', 'ycbcr2bgr',
+ 'tensor2imgs', 'imshear', 'imtranslate', 'adjust_color', 'imequalize',
+ 'adjust_brightness', 'adjust_contrast', 'lut_transform', 'clahe',
+ 'adjust_sharpness', 'auto_contrast', 'cutout', 'adjust_lighting'
+]
diff --git a/src/custom_mmpkg/custom_mmcv/image/colorspace.py b/src/custom_mmpkg/custom_mmcv/image/colorspace.py
new file mode 100644
index 0000000000000000000000000000000000000000..814533952fdfda23d67cb6a3073692d8c1156add
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/image/colorspace.py
@@ -0,0 +1,306 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import cv2
+import numpy as np
+
+
+def imconvert(img, src, dst):
+ """Convert an image from the src colorspace to dst colorspace.
+
+ Args:
+ img (ndarray): The input image.
+ src (str): The source colorspace, e.g., 'rgb', 'hsv'.
+ dst (str): The destination colorspace, e.g., 'rgb', 'hsv'.
+
+ Returns:
+ ndarray: The converted image.
+ """
+ code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}')
+ out_img = cv2.cvtColor(img, code)
+ return out_img
+
+
+def bgr2gray(img, keepdim=False):
+ """Convert a BGR image to grayscale image.
+
+ Args:
+ img (ndarray): The input image.
+ keepdim (bool): If False (by default), then return the grayscale image
+ with 2 dims, otherwise 3 dims.
+
+ Returns:
+ ndarray: The converted grayscale image.
+ """
+ out_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ if keepdim:
+ out_img = out_img[..., None]
+ return out_img
+
+
+def rgb2gray(img, keepdim=False):
+ """Convert a RGB image to grayscale image.
+
+ Args:
+ img (ndarray): The input image.
+ keepdim (bool): If False (by default), then return the grayscale image
+ with 2 dims, otherwise 3 dims.
+
+ Returns:
+ ndarray: The converted grayscale image.
+ """
+ out_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+ if keepdim:
+ out_img = out_img[..., None]
+ return out_img
+
+
+def gray2bgr(img):
+ """Convert a grayscale image to BGR image.
+
+ Args:
+ img (ndarray): The input image.
+
+ Returns:
+ ndarray: The converted BGR image.
+ """
+ img = img[..., None] if img.ndim == 2 else img
+ out_img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+ return out_img
+
+
+def gray2rgb(img):
+ """Convert a grayscale image to RGB image.
+
+ Args:
+ img (ndarray): The input image.
+
+ Returns:
+ ndarray: The converted RGB image.
+ """
+ img = img[..., None] if img.ndim == 2 else img
+ out_img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+ return out_img
+
+
+def _convert_input_type_range(img):
+ """Convert the type and range of the input image.
+
+ It converts the input image to np.float32 type and range of [0, 1].
+ It is mainly used for pre-processing the input image in colorspace
+ conversion functions such as rgb2ycbcr and ycbcr2rgb.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ (ndarray): The converted image with type of np.float32 and range of
+ [0, 1].
+ """
+ img_type = img.dtype
+ img = img.astype(np.float32)
+ if img_type == np.float32:
+ pass
+ elif img_type == np.uint8:
+ img /= 255.
+ else:
+ raise TypeError('The img type should be np.float32 or np.uint8, '
+ f'but got {img_type}')
+ return img
+
+
+def _convert_output_type_range(img, dst_type):
+ """Convert the type and range of the image according to dst_type.
+
+ It converts the image to desired type and range. If `dst_type` is np.uint8,
+ images will be converted to np.uint8 type with range [0, 255]. If
+ `dst_type` is np.float32, it converts the image to np.float32 type with
+ range [0, 1].
+ It is mainly used for post-processing images in colorspace conversion
+ functions such as rgb2ycbcr and ycbcr2rgb.
+
+ Args:
+ img (ndarray): The image to be converted with np.float32 type and
+ range [0, 255].
+ dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
+ converts the image to np.uint8 type with range [0, 255]. If
+ dst_type is np.float32, it converts the image to np.float32 type
+ with range [0, 1].
+
+ Returns:
+ (ndarray): The converted image with desired type and range.
+ """
+ if dst_type not in (np.uint8, np.float32):
+ raise TypeError('The dst_type should be np.float32 or np.uint8, '
+ f'but got {dst_type}')
+ if dst_type == np.uint8:
+ img = img.round()
+ else:
+ img /= 255.
+ return img.astype(dst_type)
+
+
+def rgb2ycbcr(img, y_only=False):
+ """Convert a RGB image to YCbCr image.
+
+ This function produces the same results as Matlab's `rgb2ycbcr` function.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+ y_only (bool): Whether to only return Y channel. Default: False.
+
+ Returns:
+ ndarray: The converted YCbCr image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img)
+ if y_only:
+ out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
+ else:
+ out_img = np.matmul(
+ img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+ [24.966, 112.0, -18.214]]) + [16, 128, 128]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def bgr2ycbcr(img, y_only=False):
+ """Convert a BGR image to YCbCr image.
+
+ The bgr version of rgb2ycbcr.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+ y_only (bool): Whether to only return Y channel. Default: False.
+
+ Returns:
+ ndarray: The converted YCbCr image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img)
+ if y_only:
+ out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
+ else:
+ out_img = np.matmul(
+ img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+ [65.481, -37.797, 112.0]]) + [16, 128, 128]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def ycbcr2rgb(img):
+ """Convert a YCbCr image to RGB image.
+
+ This function produces the same results as Matlab's ycbcr2rgb function.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ ndarray: The converted RGB image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img) * 255
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
+ [0, -0.00153632, 0.00791071],
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [
+ -222.921, 135.576, -276.836
+ ]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def ycbcr2bgr(img):
+ """Convert a YCbCr image to BGR image.
+
+ The bgr version of ycbcr2rgb.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ ndarray: The converted BGR image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img) * 255
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
+ [0.00791071, -0.00153632, 0],
+ [0, -0.00318811, 0.00625893]]) * 255.0 + [
+ -276.836, 135.576, -222.921
+ ]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def convert_color_factory(src, dst):
+
+ code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}')
+
+ def convert_color(img):
+ out_img = cv2.cvtColor(img, code)
+ return out_img
+
+ convert_color.__doc__ = f"""Convert a {src.upper()} image to {dst.upper()}
+ image.
+
+ Args:
+ img (ndarray or str): The input image.
+
+ Returns:
+ ndarray: The converted {dst.upper()} image.
+ """
+
+ return convert_color
+
+
+bgr2rgb = convert_color_factory('bgr', 'rgb')
+
+rgb2bgr = convert_color_factory('rgb', 'bgr')
+
+bgr2hsv = convert_color_factory('bgr', 'hsv')
+
+hsv2bgr = convert_color_factory('hsv', 'bgr')
+
+bgr2hls = convert_color_factory('bgr', 'hls')
+
+hls2bgr = convert_color_factory('hls', 'bgr')
diff --git a/src/custom_mmpkg/custom_mmcv/image/geometric.py b/src/custom_mmpkg/custom_mmcv/image/geometric.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf97c201cb4e43796c911919d03fb26a07ed817d
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/image/geometric.py
@@ -0,0 +1,728 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numbers
+
+import cv2
+import numpy as np
+
+from ..utils import to_2tuple
+from .io import imread_backend
+
+try:
+ from PIL import Image
+except ImportError:
+ Image = None
+
+
+def _scale_size(size, scale):
+ """Rescale a size by a ratio.
+
+ Args:
+ size (tuple[int]): (w, h).
+ scale (float | tuple(float)): Scaling factor.
+
+ Returns:
+ tuple[int]: scaled size.
+ """
+ if isinstance(scale, (float, int)):
+ scale = (scale, scale)
+ w, h = size
+ return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)
+
+
+cv2_interp_codes = {
+ 'nearest': cv2.INTER_NEAREST,
+ 'bilinear': cv2.INTER_LINEAR,
+ 'bicubic': cv2.INTER_CUBIC,
+ 'area': cv2.INTER_AREA,
+ 'lanczos': cv2.INTER_LANCZOS4
+}
+
+if Image is not None:
+ pillow_interp_codes = {
+ 'nearest': Image.NEAREST,
+ 'bilinear': Image.BILINEAR,
+ 'bicubic': Image.BICUBIC,
+ 'box': Image.BOX,
+ 'lanczos': Image.LANCZOS,
+ 'hamming': Image.HAMMING
+ }
+
+
+def imresize(img,
+ size,
+ return_scale=False,
+ interpolation='bilinear',
+ out=None,
+ backend=None):
+ """Resize image to a given size.
+
+ Args:
+ img (ndarray): The input image.
+ size (tuple[int]): Target size (w, h).
+ return_scale (bool): Whether to return `w_scale` and `h_scale`.
+ interpolation (str): Interpolation method, accepted values are
+ "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
+ backend, "nearest", "bilinear" for 'pillow' backend.
+ out (ndarray): The output destination.
+ backend (str | None): The image resize backend type. Options are `cv2`,
+ `pillow`, `None`. If backend is None, the global imread_backend
+ specified by ``mmcv.use_backend()`` will be used. Default: None.
+
+ Returns:
+ tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+ `resized_img`.
+ """
+ h, w = img.shape[:2]
+ if backend is None:
+ backend = imread_backend
+ if backend not in ['cv2', 'pillow']:
+ raise ValueError(f'backend: {backend} is not supported for resize.'
+ f"Supported backends are 'cv2', 'pillow'")
+
+ if backend == 'pillow':
+ assert img.dtype == np.uint8, 'Pillow backend only support uint8 type'
+ pil_image = Image.fromarray(img)
+ pil_image = pil_image.resize(size, pillow_interp_codes[interpolation])
+ resized_img = np.array(pil_image)
+ else:
+ resized_img = cv2.resize(
+ img, size, dst=out, interpolation=cv2_interp_codes[interpolation])
+ if not return_scale:
+ return resized_img
+ else:
+ w_scale = size[0] / w
+ h_scale = size[1] / h
+ return resized_img, w_scale, h_scale
+
+
+def imresize_to_multiple(img,
+ divisor,
+ size=None,
+ scale_factor=None,
+ keep_ratio=False,
+ return_scale=False,
+ interpolation='bilinear',
+ out=None,
+ backend=None):
+ """Resize image according to a given size or scale factor and then rounds
+ up the the resized or rescaled image size to the nearest value that can be
+ divided by the divisor.
+
+ Args:
+ img (ndarray): The input image.
+ divisor (int | tuple): Resized image size will be a multiple of
+ divisor. If divisor is a tuple, divisor should be
+ (w_divisor, h_divisor).
+ size (None | int | tuple[int]): Target size (w, h). Default: None.
+ scale_factor (None | float | tuple[float]): Multiplier for spatial
+ size. Should match input size if it is a tuple and the 2D style is
+ (w_scale_factor, h_scale_factor). Default: None.
+ keep_ratio (bool): Whether to keep the aspect ratio when resizing the
+ image. Default: False.
+ return_scale (bool): Whether to return `w_scale` and `h_scale`.
+ interpolation (str): Interpolation method, accepted values are
+ "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
+ backend, "nearest", "bilinear" for 'pillow' backend.
+ out (ndarray): The output destination.
+ backend (str | None): The image resize backend type. Options are `cv2`,
+ `pillow`, `None`. If backend is None, the global imread_backend
+ specified by ``mmcv.use_backend()`` will be used. Default: None.
+
+ Returns:
+ tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+ `resized_img`.
+ """
+ h, w = img.shape[:2]
+ if size is not None and scale_factor is not None:
+ raise ValueError('only one of size or scale_factor should be defined')
+ elif size is None and scale_factor is None:
+ raise ValueError('one of size or scale_factor should be defined')
+ elif size is not None:
+ size = to_2tuple(size)
+ if keep_ratio:
+ size = rescale_size((w, h), size, return_scale=False)
+ else:
+ size = _scale_size((w, h), scale_factor)
+
+ divisor = to_2tuple(divisor)
+ size = tuple([int(np.ceil(s / d)) * d for s, d in zip(size, divisor)])
+ resized_img, w_scale, h_scale = imresize(
+ img,
+ size,
+ return_scale=True,
+ interpolation=interpolation,
+ out=out,
+ backend=backend)
+ if return_scale:
+ return resized_img, w_scale, h_scale
+ else:
+ return resized_img
+
+
+def imresize_like(img,
+ dst_img,
+ return_scale=False,
+ interpolation='bilinear',
+ backend=None):
+ """Resize image to the same size of a given image.
+
+ Args:
+ img (ndarray): The input image.
+ dst_img (ndarray): The target image.
+ return_scale (bool): Whether to return `w_scale` and `h_scale`.
+ interpolation (str): Same as :func:`resize`.
+ backend (str | None): Same as :func:`resize`.
+
+ Returns:
+ tuple or ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+ `resized_img`.
+ """
+ h, w = dst_img.shape[:2]
+ return imresize(img, (w, h), return_scale, interpolation, backend=backend)
+
+
+def rescale_size(old_size, scale, return_scale=False):
+ """Calculate the new size to be rescaled to.
+
+ Args:
+ old_size (tuple[int]): The old size (w, h) of image.
+ scale (float | tuple[int]): The scaling factor or maximum size.
+ If it is a float number, then the image will be rescaled by this
+ factor, else if it is a tuple of 2 integers, then the image will
+ be rescaled as large as possible within the scale.
+ return_scale (bool): Whether to return the scaling factor besides the
+ rescaled image size.
+
+ Returns:
+ tuple[int]: The new rescaled image size.
+ """
+ w, h = old_size
+ if isinstance(scale, (float, int)):
+ if scale <= 0:
+ raise ValueError(f'Invalid scale {scale}, must be positive.')
+ scale_factor = scale
+ elif isinstance(scale, tuple):
+ max_long_edge = max(scale)
+ max_short_edge = min(scale)
+ scale_factor = min(max_long_edge / max(h, w),
+ max_short_edge / min(h, w))
+ else:
+ raise TypeError(
+ f'Scale must be a number or tuple of int, but got {type(scale)}')
+
+ new_size = _scale_size((w, h), scale_factor)
+
+ if return_scale:
+ return new_size, scale_factor
+ else:
+ return new_size
+
+
+def imrescale(img,
+ scale,
+ return_scale=False,
+ interpolation='bilinear',
+ backend=None):
+ """Resize image while keeping the aspect ratio.
+
+ Args:
+ img (ndarray): The input image.
+ scale (float | tuple[int]): The scaling factor or maximum size.
+ If it is a float number, then the image will be rescaled by this
+ factor, else if it is a tuple of 2 integers, then the image will
+ be rescaled as large as possible within the scale.
+ return_scale (bool): Whether to return the scaling factor besides the
+ rescaled image.
+ interpolation (str): Same as :func:`resize`.
+ backend (str | None): Same as :func:`resize`.
+
+ Returns:
+ ndarray: The rescaled image.
+ """
+ h, w = img.shape[:2]
+ new_size, scale_factor = rescale_size((w, h), scale, return_scale=True)
+ rescaled_img = imresize(
+ img, new_size, interpolation=interpolation, backend=backend)
+ if return_scale:
+ return rescaled_img, scale_factor
+ else:
+ return rescaled_img
+
+
+def imflip(img, direction='horizontal'):
+ """Flip an image horizontally or vertically.
+
+ Args:
+ img (ndarray): Image to be flipped.
+ direction (str): The flip direction, either "horizontal" or
+ "vertical" or "diagonal".
+
+ Returns:
+ ndarray: The flipped image.
+ """
+ assert direction in ['horizontal', 'vertical', 'diagonal']
+ if direction == 'horizontal':
+ return np.flip(img, axis=1)
+ elif direction == 'vertical':
+ return np.flip(img, axis=0)
+ else:
+ return np.flip(img, axis=(0, 1))
+
+
+def imflip_(img, direction='horizontal'):
+ """Inplace flip an image horizontally or vertically.
+
+ Args:
+ img (ndarray): Image to be flipped.
+ direction (str): The flip direction, either "horizontal" or
+ "vertical" or "diagonal".
+
+ Returns:
+ ndarray: The flipped image (inplace).
+ """
+ assert direction in ['horizontal', 'vertical', 'diagonal']
+ if direction == 'horizontal':
+ return cv2.flip(img, 1, img)
+ elif direction == 'vertical':
+ return cv2.flip(img, 0, img)
+ else:
+ return cv2.flip(img, -1, img)
+
+
+def imrotate(img,
+ angle,
+ center=None,
+ scale=1.0,
+ border_value=0,
+ interpolation='bilinear',
+ auto_bound=False):
+ """Rotate an image.
+
+ Args:
+ img (ndarray): Image to be rotated.
+ angle (float): Rotation angle in degrees, positive values mean
+ clockwise rotation.
+ center (tuple[float], optional): Center point (w, h) of the rotation in
+ the source image. If not specified, the center of the image will be
+ used.
+ scale (float): Isotropic scale factor.
+ border_value (int): Border value.
+ interpolation (str): Same as :func:`resize`.
+ auto_bound (bool): Whether to adjust the image size to cover the whole
+ rotated image.
+
+ Returns:
+ ndarray: The rotated image.
+ """
+ if center is not None and auto_bound:
+ raise ValueError('`auto_bound` conflicts with `center`')
+ h, w = img.shape[:2]
+ if center is None:
+ center = ((w - 1) * 0.5, (h - 1) * 0.5)
+ assert isinstance(center, tuple)
+
+ matrix = cv2.getRotationMatrix2D(center, -angle, scale)
+ if auto_bound:
+ cos = np.abs(matrix[0, 0])
+ sin = np.abs(matrix[0, 1])
+ new_w = h * sin + w * cos
+ new_h = h * cos + w * sin
+ matrix[0, 2] += (new_w - w) * 0.5
+ matrix[1, 2] += (new_h - h) * 0.5
+ w = int(np.round(new_w))
+ h = int(np.round(new_h))
+ rotated = cv2.warpAffine(
+ img,
+ matrix, (w, h),
+ flags=cv2_interp_codes[interpolation],
+ borderValue=border_value)
+ return rotated
+
+
+def bbox_clip(bboxes, img_shape):
+ """Clip bboxes to fit the image shape.
+
+ Args:
+ bboxes (ndarray): Shape (..., 4*k)
+ img_shape (tuple[int]): (height, width) of the image.
+
+ Returns:
+ ndarray: Clipped bboxes.
+ """
+ assert bboxes.shape[-1] % 4 == 0
+ cmin = np.empty(bboxes.shape[-1], dtype=bboxes.dtype)
+ cmin[0::2] = img_shape[1] - 1
+ cmin[1::2] = img_shape[0] - 1
+ clipped_bboxes = np.maximum(np.minimum(bboxes, cmin), 0)
+ return clipped_bboxes
+
+
+def bbox_scaling(bboxes, scale, clip_shape=None):
+ """Scaling bboxes w.r.t the box center.
+
+ Args:
+ bboxes (ndarray): Shape(..., 4).
+ scale (float): Scaling factor.
+ clip_shape (tuple[int], optional): If specified, bboxes that exceed the
+ boundary will be clipped according to the given shape (h, w).
+
+ Returns:
+ ndarray: Scaled bboxes.
+ """
+ if float(scale) == 1.0:
+ scaled_bboxes = bboxes.copy()
+ else:
+ w = bboxes[..., 2] - bboxes[..., 0] + 1
+ h = bboxes[..., 3] - bboxes[..., 1] + 1
+ dw = (w * (scale - 1)) * 0.5
+ dh = (h * (scale - 1)) * 0.5
+ scaled_bboxes = bboxes + np.stack((-dw, -dh, dw, dh), axis=-1)
+ if clip_shape is not None:
+ return bbox_clip(scaled_bboxes, clip_shape)
+ else:
+ return scaled_bboxes
+
+
+def imcrop(img, bboxes, scale=1.0, pad_fill=None):
+ """Crop image patches.
+
+ 3 steps: scale the bboxes -> clip bboxes -> crop and pad.
+
+ Args:
+ img (ndarray): Image to be cropped.
+ bboxes (ndarray): Shape (k, 4) or (4, ), location of cropped bboxes.
+ scale (float, optional): Scale ratio of bboxes, the default value
+ 1.0 means no padding.
+ pad_fill (Number | list[Number]): Value to be filled for padding.
+ Default: None, which means no padding.
+
+ Returns:
+ list[ndarray] | ndarray: The cropped image patches.
+ """
+ chn = 1 if img.ndim == 2 else img.shape[2]
+ if pad_fill is not None:
+ if isinstance(pad_fill, (int, float)):
+ pad_fill = [pad_fill for _ in range(chn)]
+ assert len(pad_fill) == chn
+
+ _bboxes = bboxes[None, ...] if bboxes.ndim == 1 else bboxes
+ scaled_bboxes = bbox_scaling(_bboxes, scale).astype(np.int32)
+ clipped_bbox = bbox_clip(scaled_bboxes, img.shape)
+
+ patches = []
+ for i in range(clipped_bbox.shape[0]):
+ x1, y1, x2, y2 = tuple(clipped_bbox[i, :])
+ if pad_fill is None:
+ patch = img[y1:y2 + 1, x1:x2 + 1, ...]
+ else:
+ _x1, _y1, _x2, _y2 = tuple(scaled_bboxes[i, :])
+ if chn == 1:
+ patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1)
+ else:
+ patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1, chn)
+ patch = np.array(
+ pad_fill, dtype=img.dtype) * np.ones(
+ patch_shape, dtype=img.dtype)
+ x_start = 0 if _x1 >= 0 else -_x1
+ y_start = 0 if _y1 >= 0 else -_y1
+ w = x2 - x1 + 1
+ h = y2 - y1 + 1
+ patch[y_start:y_start + h, x_start:x_start + w,
+ ...] = img[y1:y1 + h, x1:x1 + w, ...]
+ patches.append(patch)
+
+ if bboxes.ndim == 1:
+ return patches[0]
+ else:
+ return patches
+
+
+def impad(img,
+ *,
+ shape=None,
+ padding=None,
+ pad_val=0,
+ padding_mode='constant'):
+ """Pad the given image to a certain shape or pad on all sides with
+ specified padding mode and padding value.
+
+ Args:
+ img (ndarray): Image to be padded.
+ shape (tuple[int]): Expected padding shape (h, w). Default: None.
+ padding (int or tuple[int]): Padding on each border. If a single int is
+ provided this is used to pad all borders. If tuple of length 2 is
+ provided this is the padding on left/right and top/bottom
+ respectively. If a tuple of length 4 is provided this is the
+ padding for the left, top, right and bottom borders respectively.
+ Default: None. Note that `shape` and `padding` can not be both
+ set.
+ pad_val (Number | Sequence[Number]): Values to be filled in padding
+ areas when padding_mode is 'constant'. Default: 0.
+ padding_mode (str): Type of padding. Should be: constant, edge,
+ reflect or symmetric. Default: constant.
+
+ - constant: pads with a constant value, this value is specified
+ with pad_val.
+ - edge: pads with the last value at the edge of the image.
+ - reflect: pads with reflection of image without repeating the
+ last value on the edge. For example, padding [1, 2, 3, 4]
+ with 2 elements on both sides in reflect mode will result
+ in [3, 2, 1, 2, 3, 4, 3, 2].
+ - symmetric: pads with reflection of image repeating the last
+ value on the edge. For example, padding [1, 2, 3, 4] with
+ 2 elements on both sides in symmetric mode will result in
+ [2, 1, 1, 2, 3, 4, 4, 3]
+
+ Returns:
+ ndarray: The padded image.
+ """
+
+ assert (shape is not None) ^ (padding is not None)
+ if shape is not None:
+ padding = (0, 0, shape[1] - img.shape[1], shape[0] - img.shape[0])
+
+ # check pad_val
+ if isinstance(pad_val, tuple):
+ assert len(pad_val) == img.shape[-1]
+ elif not isinstance(pad_val, numbers.Number):
+ raise TypeError('pad_val must be a int or a tuple. '
+ f'But received {type(pad_val)}')
+
+ # check padding
+ if isinstance(padding, tuple) and len(padding) in [2, 4]:
+ if len(padding) == 2:
+ padding = (padding[0], padding[1], padding[0], padding[1])
+ elif isinstance(padding, numbers.Number):
+ padding = (padding, padding, padding, padding)
+ else:
+ raise ValueError('Padding must be a int or a 2, or 4 element tuple.'
+ f'But received {padding}')
+
+ # check padding mode
+ assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
+
+ border_type = {
+ 'constant': cv2.BORDER_CONSTANT,
+ 'edge': cv2.BORDER_REPLICATE,
+ 'reflect': cv2.BORDER_REFLECT_101,
+ 'symmetric': cv2.BORDER_REFLECT
+ }
+ img = cv2.copyMakeBorder(
+ img,
+ padding[1],
+ padding[3],
+ padding[0],
+ padding[2],
+ border_type[padding_mode],
+ value=pad_val)
+
+ return img
+
+
+def impad_to_multiple(img, divisor, pad_val=0):
+ """Pad an image to ensure each edge to be multiple to some number.
+
+ Args:
+ img (ndarray): Image to be padded.
+ divisor (int): Padded image edges will be multiple to divisor.
+ pad_val (Number | Sequence[Number]): Same as :func:`impad`.
+
+ Returns:
+ ndarray: The padded image.
+ """
+ pad_h = int(np.ceil(img.shape[0] / divisor)) * divisor
+ pad_w = int(np.ceil(img.shape[1] / divisor)) * divisor
+ return impad(img, shape=(pad_h, pad_w), pad_val=pad_val)
+
+
+def cutout(img, shape, pad_val=0):
+ """Randomly cut out a rectangle from the original img.
+
+ Args:
+ img (ndarray): Image to be cutout.
+ shape (int | tuple[int]): Expected cutout shape (h, w). If given as a
+ int, the value will be used for both h and w.
+ pad_val (int | float | tuple[int | float]): Values to be filled in the
+ cut area. Defaults to 0.
+
+ Returns:
+ ndarray: The cutout image.
+ """
+
+ channels = 1 if img.ndim == 2 else img.shape[2]
+ if isinstance(shape, int):
+ cut_h, cut_w = shape, shape
+ else:
+ assert isinstance(shape, tuple) and len(shape) == 2, \
+ f'shape must be a int or a tuple with length 2, but got type ' \
+ f'{type(shape)} instead.'
+ cut_h, cut_w = shape
+ if isinstance(pad_val, (int, float)):
+ pad_val = tuple([pad_val] * channels)
+ elif isinstance(pad_val, tuple):
+ assert len(pad_val) == channels, \
+ 'Expected the num of elements in tuple equals the channels' \
+ 'of input image. Found {} vs {}'.format(
+ len(pad_val), channels)
+ else:
+ raise TypeError(f'Invalid type {type(pad_val)} for `pad_val`')
+
+ img_h, img_w = img.shape[:2]
+ y0 = np.random.uniform(img_h)
+ x0 = np.random.uniform(img_w)
+
+ y1 = int(max(0, y0 - cut_h / 2.))
+ x1 = int(max(0, x0 - cut_w / 2.))
+ y2 = min(img_h, y1 + cut_h)
+ x2 = min(img_w, x1 + cut_w)
+
+ if img.ndim == 2:
+ patch_shape = (y2 - y1, x2 - x1)
+ else:
+ patch_shape = (y2 - y1, x2 - x1, channels)
+
+ img_cutout = img.copy()
+ patch = np.array(
+ pad_val, dtype=img.dtype) * np.ones(
+ patch_shape, dtype=img.dtype)
+ img_cutout[y1:y2, x1:x2, ...] = patch
+
+ return img_cutout
+
+
+def _get_shear_matrix(magnitude, direction='horizontal'):
+ """Generate the shear matrix for transformation.
+
+ Args:
+ magnitude (int | float): The magnitude used for shear.
+ direction (str): The flip direction, either "horizontal"
+ or "vertical".
+
+ Returns:
+ ndarray: The shear matrix with dtype float32.
+ """
+ if direction == 'horizontal':
+ shear_matrix = np.float32([[1, magnitude, 0], [0, 1, 0]])
+ elif direction == 'vertical':
+ shear_matrix = np.float32([[1, 0, 0], [magnitude, 1, 0]])
+ return shear_matrix
+
+
+def imshear(img,
+ magnitude,
+ direction='horizontal',
+ border_value=0,
+ interpolation='bilinear'):
+ """Shear an image.
+
+ Args:
+ img (ndarray): Image to be sheared with format (h, w)
+ or (h, w, c).
+ magnitude (int | float): The magnitude used for shear.
+ direction (str): The flip direction, either "horizontal"
+ or "vertical".
+ border_value (int | tuple[int]): Value used in case of a
+ constant border.
+ interpolation (str): Same as :func:`resize`.
+
+ Returns:
+ ndarray: The sheared image.
+ """
+ assert direction in ['horizontal',
+ 'vertical'], f'Invalid direction: {direction}'
+ height, width = img.shape[:2]
+ if img.ndim == 2:
+ channels = 1
+ elif img.ndim == 3:
+ channels = img.shape[-1]
+ if isinstance(border_value, int):
+ border_value = tuple([border_value] * channels)
+ elif isinstance(border_value, tuple):
+ assert len(border_value) == channels, \
+ 'Expected the num of elements in tuple equals the channels' \
+ 'of input image. Found {} vs {}'.format(
+ len(border_value), channels)
+ else:
+ raise ValueError(
+ f'Invalid type {type(border_value)} for `border_value`')
+ shear_matrix = _get_shear_matrix(magnitude, direction)
+ sheared = cv2.warpAffine(
+ img,
+ shear_matrix,
+ (width, height),
+ # Note case when the number elements in `border_value`
+ # greater than 3 (e.g. shearing masks whose channels large
+ # than 3) will raise TypeError in `cv2.warpAffine`.
+ # Here simply slice the first 3 values in `border_value`.
+ borderValue=border_value[:3],
+ flags=cv2_interp_codes[interpolation])
+ return sheared
+
+
+def _get_translate_matrix(offset, direction='horizontal'):
+ """Generate the translate matrix.
+
+ Args:
+ offset (int | float): The offset used for translate.
+ direction (str): The translate direction, either
+ "horizontal" or "vertical".
+
+ Returns:
+ ndarray: The translate matrix with dtype float32.
+ """
+ if direction == 'horizontal':
+ translate_matrix = np.float32([[1, 0, offset], [0, 1, 0]])
+ elif direction == 'vertical':
+ translate_matrix = np.float32([[1, 0, 0], [0, 1, offset]])
+ return translate_matrix
+
+
+def imtranslate(img,
+ offset,
+ direction='horizontal',
+ border_value=0,
+ interpolation='bilinear'):
+ """Translate an image.
+
+ Args:
+ img (ndarray): Image to be translated with format
+ (h, w) or (h, w, c).
+ offset (int | float): The offset used for translate.
+ direction (str): The translate direction, either "horizontal"
+ or "vertical".
+ border_value (int | tuple[int]): Value used in case of a
+ constant border.
+ interpolation (str): Same as :func:`resize`.
+
+ Returns:
+ ndarray: The translated image.
+ """
+ assert direction in ['horizontal',
+ 'vertical'], f'Invalid direction: {direction}'
+ height, width = img.shape[:2]
+ if img.ndim == 2:
+ channels = 1
+ elif img.ndim == 3:
+ channels = img.shape[-1]
+ if isinstance(border_value, int):
+ border_value = tuple([border_value] * channels)
+ elif isinstance(border_value, tuple):
+ assert len(border_value) == channels, \
+ 'Expected the num of elements in tuple equals the channels' \
+ 'of input image. Found {} vs {}'.format(
+ len(border_value), channels)
+ else:
+ raise ValueError(
+ f'Invalid type {type(border_value)} for `border_value`.')
+ translate_matrix = _get_translate_matrix(offset, direction)
+ translated = cv2.warpAffine(
+ img,
+ translate_matrix,
+ (width, height),
+ # Note case when the number elements in `border_value`
+ # greater than 3 (e.g. translating masks whose channels
+ # large than 3) will raise TypeError in `cv2.warpAffine`.
+ # Here simply slice the first 3 values in `border_value`.
+ borderValue=border_value[:3],
+ flags=cv2_interp_codes[interpolation])
+ return translated
diff --git a/src/custom_mmpkg/custom_mmcv/image/io.py b/src/custom_mmpkg/custom_mmcv/image/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fe4400ddc5751cd01a554131b33eca3154e4ca7
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/image/io.py
@@ -0,0 +1,258 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import io
+import os.path as osp
+from pathlib import Path
+
+import cv2
+import numpy as np
+from cv2 import (IMREAD_COLOR, IMREAD_GRAYSCALE, IMREAD_IGNORE_ORIENTATION,
+ IMREAD_UNCHANGED)
+
+from custom_mmpkg.custom_mmcv.utils import check_file_exist, is_str, mkdir_or_exist
+
+try:
+ from turbojpeg import TJCS_RGB, TJPF_BGR, TJPF_GRAY, TurboJPEG
+except ImportError:
+ TJCS_RGB = TJPF_GRAY = TJPF_BGR = TurboJPEG = None
+
+try:
+ from PIL import Image, ImageOps
+except ImportError:
+ Image = None
+
+try:
+ import tifffile
+except ImportError:
+ tifffile = None
+
+jpeg = None
+supported_backends = ['cv2', 'turbojpeg', 'pillow', 'tifffile']
+
+imread_flags = {
+ 'color': IMREAD_COLOR,
+ 'grayscale': IMREAD_GRAYSCALE,
+ 'unchanged': IMREAD_UNCHANGED,
+ 'color_ignore_orientation': IMREAD_IGNORE_ORIENTATION | IMREAD_COLOR,
+ 'grayscale_ignore_orientation':
+ IMREAD_IGNORE_ORIENTATION | IMREAD_GRAYSCALE
+}
+
+imread_backend = 'cv2'
+
+
+def use_backend(backend):
+ """Select a backend for image decoding.
+
+ Args:
+ backend (str): The image decoding backend type. Options are `cv2`,
+ `pillow`, `turbojpeg` (see https://github.com/lilohuang/PyTurboJPEG)
+ and `tifffile`. `turbojpeg` is faster but it only supports `.jpeg`
+ file format.
+ """
+ assert backend in supported_backends
+ global imread_backend
+ imread_backend = backend
+ if imread_backend == 'turbojpeg':
+ if TurboJPEG is None:
+ raise ImportError('`PyTurboJPEG` is not installed')
+ global jpeg
+ if jpeg is None:
+ jpeg = TurboJPEG()
+ elif imread_backend == 'pillow':
+ if Image is None:
+ raise ImportError('`Pillow` is not installed')
+ elif imread_backend == 'tifffile':
+ if tifffile is None:
+ raise ImportError('`tifffile` is not installed')
+
+
+def _jpegflag(flag='color', channel_order='bgr'):
+ channel_order = channel_order.lower()
+ if channel_order not in ['rgb', 'bgr']:
+ raise ValueError('channel order must be either "rgb" or "bgr"')
+
+ if flag == 'color':
+ if channel_order == 'bgr':
+ return TJPF_BGR
+ elif channel_order == 'rgb':
+ return TJCS_RGB
+ elif flag == 'grayscale':
+ return TJPF_GRAY
+ else:
+ raise ValueError('flag must be "color" or "grayscale"')
+
+
+def _pillow2array(img, flag='color', channel_order='bgr'):
+ """Convert a pillow image to numpy array.
+
+ Args:
+ img (:obj:`PIL.Image.Image`): The image loaded using PIL
+ flag (str): Flags specifying the color type of a loaded image,
+ candidates are 'color', 'grayscale' and 'unchanged'.
+ Default to 'color'.
+ channel_order (str): The channel order of the output image array,
+ candidates are 'bgr' and 'rgb'. Default to 'bgr'.
+
+ Returns:
+ np.ndarray: The converted numpy array
+ """
+ channel_order = channel_order.lower()
+ if channel_order not in ['rgb', 'bgr']:
+ raise ValueError('channel order must be either "rgb" or "bgr"')
+
+ if flag == 'unchanged':
+ array = np.array(img)
+ if array.ndim >= 3 and array.shape[2] >= 3: # color image
+ array[:, :, :3] = array[:, :, (2, 1, 0)] # RGB to BGR
+ else:
+ # Handle exif orientation tag
+ if flag in ['color', 'grayscale']:
+ img = ImageOps.exif_transpose(img)
+ # If the image mode is not 'RGB', convert it to 'RGB' first.
+ if img.mode != 'RGB':
+ if img.mode != 'LA':
+ # Most formats except 'LA' can be directly converted to RGB
+ img = img.convert('RGB')
+ else:
+ # When the mode is 'LA', the default conversion will fill in
+ # the canvas with black, which sometimes shadows black objects
+ # in the foreground.
+ #
+ # Therefore, a random color (124, 117, 104) is used for canvas
+ img_rgba = img.convert('RGBA')
+ img = Image.new('RGB', img_rgba.size, (124, 117, 104))
+ img.paste(img_rgba, mask=img_rgba.split()[3]) # 3 is alpha
+ if flag in ['color', 'color_ignore_orientation']:
+ array = np.array(img)
+ if channel_order != 'rgb':
+ array = array[:, :, ::-1] # RGB to BGR
+ elif flag in ['grayscale', 'grayscale_ignore_orientation']:
+ img = img.convert('L')
+ array = np.array(img)
+ else:
+ raise ValueError(
+ 'flag must be "color", "grayscale", "unchanged", '
+ f'"color_ignore_orientation" or "grayscale_ignore_orientation"'
+ f' but got {flag}')
+ return array
+
+
+def imread(img_or_path, flag='color', channel_order='bgr', backend=None):
+ """Read an image.
+
+ Args:
+ img_or_path (ndarray or str or Path): Either a numpy array or str or
+ pathlib.Path. If it is a numpy array (loaded image), then
+ it will be returned as is.
+ flag (str): Flags specifying the color type of a loaded image,
+ candidates are `color`, `grayscale`, `unchanged`,
+ `color_ignore_orientation` and `grayscale_ignore_orientation`.
+ By default, `cv2` and `pillow` backend would rotate the image
+ according to its EXIF info unless called with `unchanged` or
+ `*_ignore_orientation` flags. `turbojpeg` and `tifffile` backend
+ always ignore image's EXIF info regardless of the flag.
+ The `turbojpeg` backend only supports `color` and `grayscale`.
+ channel_order (str): Order of channel, candidates are `bgr` and `rgb`.
+ backend (str | None): The image decoding backend type. Options are
+ `cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`.
+ If backend is None, the global imread_backend specified by
+ ``mmcv.use_backend()`` will be used. Default: None.
+
+ Returns:
+ ndarray: Loaded image array.
+ """
+
+ if backend is None:
+ backend = imread_backend
+ if backend not in supported_backends:
+ raise ValueError(f'backend: {backend} is not supported. Supported '
+ "backends are 'cv2', 'turbojpeg', 'pillow'")
+ if isinstance(img_or_path, Path):
+ img_or_path = str(img_or_path)
+
+ if isinstance(img_or_path, np.ndarray):
+ return img_or_path
+ elif is_str(img_or_path):
+ check_file_exist(img_or_path,
+ f'img file does not exist: {img_or_path}')
+ if backend == 'turbojpeg':
+ with open(img_or_path, 'rb') as in_file:
+ img = jpeg.decode(in_file.read(),
+ _jpegflag(flag, channel_order))
+ if img.shape[-1] == 1:
+ img = img[:, :, 0]
+ return img
+ elif backend == 'pillow':
+ img = Image.open(img_or_path)
+ img = _pillow2array(img, flag, channel_order)
+ return img
+ elif backend == 'tifffile':
+ img = tifffile.imread(img_or_path)
+ return img
+ else:
+ flag = imread_flags[flag] if is_str(flag) else flag
+ img = cv2.imread(img_or_path, flag)
+ if flag == IMREAD_COLOR and channel_order == 'rgb':
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
+ return img
+ else:
+ raise TypeError('"img" must be a numpy array or a str or '
+ 'a pathlib.Path object')
+
+
+def imfrombytes(content, flag='color', channel_order='bgr', backend=None):
+ """Read an image from bytes.
+
+ Args:
+ content (bytes): Image bytes got from files or other streams.
+ flag (str): Same as :func:`imread`.
+ backend (str | None): The image decoding backend type. Options are
+ `cv2`, `pillow`, `turbojpeg`, `None`. If backend is None, the
+ global imread_backend specified by ``mmcv.use_backend()`` will be
+ used. Default: None.
+
+ Returns:
+ ndarray: Loaded image array.
+ """
+
+ if backend is None:
+ backend = imread_backend
+ if backend not in supported_backends:
+ raise ValueError(f'backend: {backend} is not supported. Supported '
+ "backends are 'cv2', 'turbojpeg', 'pillow'")
+ if backend == 'turbojpeg':
+ img = jpeg.decode(content, _jpegflag(flag, channel_order))
+ if img.shape[-1] == 1:
+ img = img[:, :, 0]
+ return img
+ elif backend == 'pillow':
+ buff = io.BytesIO(content)
+ img = Image.open(buff)
+ img = _pillow2array(img, flag, channel_order)
+ return img
+ else:
+ img_np = np.frombuffer(content, np.uint8)
+ flag = imread_flags[flag] if is_str(flag) else flag
+ img = cv2.imdecode(img_np, flag)
+ if flag == IMREAD_COLOR and channel_order == 'rgb':
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
+ return img
+
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+ """Write image to file.
+
+ Args:
+ img (ndarray): Image array to be written.
+ file_path (str): Image file path.
+ params (None or list): Same as opencv :func:`imwrite` interface.
+ auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+ whether to create it automatically.
+
+ Returns:
+ bool: Successful or not.
+ """
+ if auto_mkdir:
+ dir_name = osp.abspath(osp.dirname(file_path))
+ mkdir_or_exist(dir_name)
+ return cv2.imwrite(file_path, img, params)
diff --git a/src/custom_mmpkg/custom_mmcv/image/misc.py b/src/custom_mmpkg/custom_mmcv/image/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8a1aae4510cdef05b9f61a664818c06760cea77
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/image/misc.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+
+import custom_mmpkg.custom_mmcv as mmcv
+
+try:
+ import torch
+except ImportError:
+ torch = None
+
+
+def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True):
+ """Convert tensor to 3-channel images.
+
+ Args:
+ tensor (torch.Tensor): Tensor that contains multiple images, shape (
+ N, C, H, W).
+ mean (tuple[float], optional): Mean of images. Defaults to (0, 0, 0).
+ std (tuple[float], optional): Standard deviation of images.
+ Defaults to (1, 1, 1).
+ to_rgb (bool, optional): Whether the tensor was converted to RGB
+ format in the first place. If so, convert it back to BGR.
+ Defaults to True.
+
+ Returns:
+ list[np.ndarray]: A list that contains multiple images.
+ """
+
+ if torch is None:
+ raise RuntimeError('pytorch is not installed')
+ assert torch.is_tensor(tensor) and tensor.ndim == 4
+ assert len(mean) == 3
+ assert len(std) == 3
+
+ num_imgs = tensor.size(0)
+ mean = np.array(mean, dtype=np.float32)
+ std = np.array(std, dtype=np.float32)
+ imgs = []
+ for img_id in range(num_imgs):
+ img = tensor[img_id, ...].cpu().numpy().transpose(1, 2, 0)
+ img = mmcv.imdenormalize(
+ img, mean, std, to_bgr=to_rgb).astype(np.uint8)
+ imgs.append(np.ascontiguousarray(img))
+ return imgs
diff --git a/src/custom_mmpkg/custom_mmcv/image/photometric.py b/src/custom_mmpkg/custom_mmcv/image/photometric.py
new file mode 100644
index 0000000000000000000000000000000000000000..5085d012019c0cbf56f66f421a378278c1a058ae
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/image/photometric.py
@@ -0,0 +1,428 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import cv2
+import numpy as np
+
+from ..utils import is_tuple_of
+from .colorspace import bgr2gray, gray2bgr
+
+
+def imnormalize(img, mean, std, to_rgb=True):
+ """Normalize an image with mean and std.
+
+ Args:
+ img (ndarray): Image to be normalized.
+ mean (ndarray): The mean to be used for normalize.
+ std (ndarray): The std to be used for normalize.
+ to_rgb (bool): Whether to convert to rgb.
+
+ Returns:
+ ndarray: The normalized image.
+ """
+ img = img.copy().astype(np.float32)
+ return imnormalize_(img, mean, std, to_rgb)
+
+
+def imnormalize_(img, mean, std, to_rgb=True):
+ """Inplace normalize an image with mean and std.
+
+ Args:
+ img (ndarray): Image to be normalized.
+ mean (ndarray): The mean to be used for normalize.
+ std (ndarray): The std to be used for normalize.
+ to_rgb (bool): Whether to convert to rgb.
+
+ Returns:
+ ndarray: The normalized image.
+ """
+ # cv2 inplace normalization does not accept uint8
+ assert img.dtype != np.uint8
+ mean = np.float64(mean.reshape(1, -1))
+ stdinv = 1 / np.float64(std.reshape(1, -1))
+ if to_rgb:
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace
+ cv2.subtract(img, mean, img) # inplace
+ cv2.multiply(img, stdinv, img) # inplace
+ return img
+
+
+def imdenormalize(img, mean, std, to_bgr=True):
+ assert img.dtype != np.uint8
+ mean = mean.reshape(1, -1).astype(np.float64)
+ std = std.reshape(1, -1).astype(np.float64)
+ img = cv2.multiply(img, std) # make a copy
+ cv2.add(img, mean, img) # inplace
+ if to_bgr:
+ cv2.cvtColor(img, cv2.COLOR_RGB2BGR, img) # inplace
+ return img
+
+
+def iminvert(img):
+ """Invert (negate) an image.
+
+ Args:
+ img (ndarray): Image to be inverted.
+
+ Returns:
+ ndarray: The inverted image.
+ """
+ return np.full_like(img, 255) - img
+
+
+def solarize(img, thr=128):
+ """Solarize an image (invert all pixel values above a threshold)
+
+ Args:
+ img (ndarray): Image to be solarized.
+ thr (int): Threshold for solarizing (0 - 255).
+
+ Returns:
+ ndarray: The solarized image.
+ """
+ img = np.where(img < thr, img, 255 - img)
+ return img
+
+
+def posterize(img, bits):
+ """Posterize an image (reduce the number of bits for each color channel)
+
+ Args:
+ img (ndarray): Image to be posterized.
+ bits (int): Number of bits (1 to 8) to use for posterizing.
+
+ Returns:
+ ndarray: The posterized image.
+ """
+ shift = 8 - bits
+ img = np.left_shift(np.right_shift(img, shift), shift)
+ return img
+
+
+def adjust_color(img, alpha=1, beta=None, gamma=0):
+ r"""It blends the source image and its gray image:
+
+ .. math::
+ output = img * alpha + gray\_img * beta + gamma
+
+ Args:
+ img (ndarray): The input source image.
+ alpha (int | float): Weight for the source image. Default 1.
+ beta (int | float): Weight for the converted gray image.
+ If None, it's assigned the value (1 - `alpha`).
+ gamma (int | float): Scalar added to each sum.
+ Same as :func:`cv2.addWeighted`. Default 0.
+
+ Returns:
+ ndarray: Colored image which has the same size and dtype as input.
+ """
+ gray_img = bgr2gray(img)
+ gray_img = np.tile(gray_img[..., None], [1, 1, 3])
+ if beta is None:
+ beta = 1 - alpha
+ colored_img = cv2.addWeighted(img, alpha, gray_img, beta, gamma)
+ if not colored_img.dtype == np.uint8:
+ # Note when the dtype of `img` is not the default `np.uint8`
+ # (e.g. np.float32), the value in `colored_img` got from cv2
+ # is not guaranteed to be in range [0, 255], so here clip
+ # is needed.
+ colored_img = np.clip(colored_img, 0, 255)
+ return colored_img
+
+
+def imequalize(img):
+ """Equalize the image histogram.
+
+ This function applies a non-linear mapping to the input image,
+ in order to create a uniform distribution of grayscale values
+ in the output image.
+
+ Args:
+ img (ndarray): Image to be equalized.
+
+ Returns:
+ ndarray: The equalized image.
+ """
+
+ def _scale_channel(im, c):
+ """Scale the data in the corresponding channel."""
+ im = im[:, :, c]
+ # Compute the histogram of the image channel.
+ histo = np.histogram(im, 256, (0, 255))[0]
+ # For computing the step, filter out the nonzeros.
+ nonzero_histo = histo[histo > 0]
+ step = (np.sum(nonzero_histo) - nonzero_histo[-1]) // 255
+ if not step:
+ lut = np.array(range(256))
+ else:
+ # Compute the cumulative sum, shifted by step // 2
+ # and then normalized by step.
+ lut = (np.cumsum(histo) + (step // 2)) // step
+ # Shift lut, prepending with 0.
+ lut = np.concatenate([[0], lut[:-1]], 0)
+ # handle potential integer overflow
+ lut[lut > 255] = 255
+ # If step is zero, return the original image.
+ # Otherwise, index from lut.
+ return np.where(np.equal(step, 0), im, lut[im])
+
+ # Scales each channel independently and then stacks
+ # the result.
+ s1 = _scale_channel(img, 0)
+ s2 = _scale_channel(img, 1)
+ s3 = _scale_channel(img, 2)
+ equalized_img = np.stack([s1, s2, s3], axis=-1)
+ return equalized_img.astype(img.dtype)
+
+
+def adjust_brightness(img, factor=1.):
+ """Adjust image brightness.
+
+ This function controls the brightness of an image. An
+ enhancement factor of 0.0 gives a black image.
+ A factor of 1.0 gives the original image. This function
+ blends the source image and the degenerated black image:
+
+ .. math::
+ output = img * factor + degenerated * (1 - factor)
+
+ Args:
+ img (ndarray): Image to be brightened.
+ factor (float): A value controls the enhancement.
+ Factor 1.0 returns the original image, lower
+ factors mean less color (brightness, contrast,
+ etc), and higher values more. Default 1.
+
+ Returns:
+ ndarray: The brightened image.
+ """
+ degenerated = np.zeros_like(img)
+ # Note manually convert the dtype to np.float32, to
+ # achieve as close results as PIL.ImageEnhance.Brightness.
+ # Set beta=1-factor, and gamma=0
+ brightened_img = cv2.addWeighted(
+ img.astype(np.float32), factor, degenerated.astype(np.float32),
+ 1 - factor, 0)
+ brightened_img = np.clip(brightened_img, 0, 255)
+ return brightened_img.astype(img.dtype)
+
+
+def adjust_contrast(img, factor=1.):
+ """Adjust image contrast.
+
+ This function controls the contrast of an image. An
+ enhancement factor of 0.0 gives a solid grey
+ image. A factor of 1.0 gives the original image. It
+ blends the source image and the degenerated mean image:
+
+ .. math::
+ output = img * factor + degenerated * (1 - factor)
+
+ Args:
+ img (ndarray): Image to be contrasted. BGR order.
+ factor (float): Same as :func:`mmcv.adjust_brightness`.
+
+ Returns:
+ ndarray: The contrasted image.
+ """
+ gray_img = bgr2gray(img)
+ hist = np.histogram(gray_img, 256, (0, 255))[0]
+ mean = round(np.sum(gray_img) / np.sum(hist))
+ degenerated = (np.ones_like(img[..., 0]) * mean).astype(img.dtype)
+ degenerated = gray2bgr(degenerated)
+ contrasted_img = cv2.addWeighted(
+ img.astype(np.float32), factor, degenerated.astype(np.float32),
+ 1 - factor, 0)
+ contrasted_img = np.clip(contrasted_img, 0, 255)
+ return contrasted_img.astype(img.dtype)
+
+
+def auto_contrast(img, cutoff=0):
+ """Auto adjust image contrast.
+
+ This function maximize (normalize) image contrast by first removing cutoff
+ percent of the lightest and darkest pixels from the histogram and remapping
+ the image so that the darkest pixel becomes black (0), and the lightest
+ becomes white (255).
+
+ Args:
+ img (ndarray): Image to be contrasted. BGR order.
+ cutoff (int | float | tuple): The cutoff percent of the lightest and
+ darkest pixels to be removed. If given as tuple, it shall be
+ (low, high). Otherwise, the single value will be used for both.
+ Defaults to 0.
+
+ Returns:
+ ndarray: The contrasted image.
+ """
+
+ def _auto_contrast_channel(im, c, cutoff):
+ im = im[:, :, c]
+ # Compute the histogram of the image channel.
+ histo = np.histogram(im, 256, (0, 255))[0]
+ # Remove cut-off percent pixels from histo
+ histo_sum = np.cumsum(histo)
+ cut_low = histo_sum[-1] * cutoff[0] // 100
+ cut_high = histo_sum[-1] - histo_sum[-1] * cutoff[1] // 100
+ histo_sum = np.clip(histo_sum, cut_low, cut_high) - cut_low
+ histo = np.concatenate([[histo_sum[0]], np.diff(histo_sum)], 0)
+
+ # Compute mapping
+ low, high = np.nonzero(histo)[0][0], np.nonzero(histo)[0][-1]
+ # If all the values have been cut off, return the origin img
+ if low >= high:
+ return im
+ scale = 255.0 / (high - low)
+ offset = -low * scale
+ lut = np.array(range(256))
+ lut = lut * scale + offset
+ lut = np.clip(lut, 0, 255)
+ return lut[im]
+
+ if isinstance(cutoff, (int, float)):
+ cutoff = (cutoff, cutoff)
+ else:
+ assert isinstance(cutoff, tuple), 'cutoff must be of type int, ' \
+ f'float or tuple, but got {type(cutoff)} instead.'
+ # Auto adjusts contrast for each channel independently and then stacks
+ # the result.
+ s1 = _auto_contrast_channel(img, 0, cutoff)
+ s2 = _auto_contrast_channel(img, 1, cutoff)
+ s3 = _auto_contrast_channel(img, 2, cutoff)
+ contrasted_img = np.stack([s1, s2, s3], axis=-1)
+ return contrasted_img.astype(img.dtype)
+
+
+def adjust_sharpness(img, factor=1., kernel=None):
+ """Adjust image sharpness.
+
+ This function controls the sharpness of an image. An
+ enhancement factor of 0.0 gives a blurred image. A
+ factor of 1.0 gives the original image. And a factor
+ of 2.0 gives a sharpened image. It blends the source
+ image and the degenerated mean image:
+
+ .. math::
+ output = img * factor + degenerated * (1 - factor)
+
+ Args:
+ img (ndarray): Image to be sharpened. BGR order.
+ factor (float): Same as :func:`mmcv.adjust_brightness`.
+ kernel (np.ndarray, optional): Filter kernel to be applied on the img
+ to obtain the degenerated img. Defaults to None.
+
+ Note:
+ No value sanity check is enforced on the kernel set by users. So with
+ an inappropriate kernel, the ``adjust_sharpness`` may fail to perform
+ the function its name indicates but end up performing whatever
+ transform determined by the kernel.
+
+ Returns:
+ ndarray: The sharpened image.
+ """
+
+ if kernel is None:
+ # adopted from PIL.ImageFilter.SMOOTH
+ kernel = np.array([[1., 1., 1.], [1., 5., 1.], [1., 1., 1.]]) / 13
+ assert isinstance(kernel, np.ndarray), \
+ f'kernel must be of type np.ndarray, but got {type(kernel)} instead.'
+ assert kernel.ndim == 2, \
+ f'kernel must have a dimension of 2, but got {kernel.ndim} instead.'
+
+ degenerated = cv2.filter2D(img, -1, kernel)
+ sharpened_img = cv2.addWeighted(
+ img.astype(np.float32), factor, degenerated.astype(np.float32),
+ 1 - factor, 0)
+ sharpened_img = np.clip(sharpened_img, 0, 255)
+ return sharpened_img.astype(img.dtype)
+
+
+def adjust_lighting(img, eigval, eigvec, alphastd=0.1, to_rgb=True):
+ """AlexNet-style PCA jitter.
+
+ This data augmentation is proposed in `ImageNet Classification with Deep
+ Convolutional Neural Networks
+ `_.
+
+ Args:
+ img (ndarray): Image to be adjusted lighting. BGR order.
+ eigval (ndarray): the eigenvalue of the convariance matrix of pixel
+ values, respectively.
+ eigvec (ndarray): the eigenvector of the convariance matrix of pixel
+ values, respectively.
+ alphastd (float): The standard deviation for distribution of alpha.
+ Defaults to 0.1
+ to_rgb (bool): Whether to convert img to rgb.
+
+ Returns:
+ ndarray: The adjusted image.
+ """
+ assert isinstance(eigval, np.ndarray) and isinstance(eigvec, np.ndarray), \
+ f'eigval and eigvec should both be of type np.ndarray, got ' \
+ f'{type(eigval)} and {type(eigvec)} instead.'
+
+ assert eigval.ndim == 1 and eigvec.ndim == 2
+ assert eigvec.shape == (3, eigval.shape[0])
+ n_eigval = eigval.shape[0]
+ assert isinstance(alphastd, float), 'alphastd should be of type float, ' \
+ f'got {type(alphastd)} instead.'
+
+ img = img.copy().astype(np.float32)
+ if to_rgb:
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace
+
+ alpha = np.random.normal(0, alphastd, n_eigval)
+ alter = eigvec \
+ * np.broadcast_to(alpha.reshape(1, n_eigval), (3, n_eigval)) \
+ * np.broadcast_to(eigval.reshape(1, n_eigval), (3, n_eigval))
+ alter = np.broadcast_to(alter.sum(axis=1).reshape(1, 1, 3), img.shape)
+ img_adjusted = img + alter
+ return img_adjusted
+
+
+def lut_transform(img, lut_table):
+ """Transform array by look-up table.
+
+ The function lut_transform fills the output array with values from the
+ look-up table. Indices of the entries are taken from the input array.
+
+ Args:
+ img (ndarray): Image to be transformed.
+ lut_table (ndarray): look-up table of 256 elements; in case of
+ multi-channel input array, the table should either have a single
+ channel (in this case the same table is used for all channels) or
+ the same number of channels as in the input array.
+
+ Returns:
+ ndarray: The transformed image.
+ """
+ assert isinstance(img, np.ndarray)
+ assert 0 <= np.min(img) and np.max(img) <= 255
+ assert isinstance(lut_table, np.ndarray)
+ assert lut_table.shape == (256, )
+
+ return cv2.LUT(np.array(img, dtype=np.uint8), lut_table)
+
+
+def clahe(img, clip_limit=40.0, tile_grid_size=(8, 8)):
+ """Use CLAHE method to process the image.
+
+ See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
+ Graphics Gems, 1994:474-485.` for more information.
+
+ Args:
+ img (ndarray): Image to be processed.
+ clip_limit (float): Threshold for contrast limiting. Default: 40.0.
+ tile_grid_size (tuple[int]): Size of grid for histogram equalization.
+ Input image will be divided into equally sized rectangular tiles.
+ It defines the number of tiles in row and column. Default: (8, 8).
+
+ Returns:
+ ndarray: The processed image.
+ """
+ assert isinstance(img, np.ndarray)
+ assert img.ndim == 2
+ assert isinstance(clip_limit, (float, int))
+ assert is_tuple_of(tile_grid_size, int)
+ assert len(tile_grid_size) == 2
+
+ clahe = cv2.createCLAHE(clip_limit, tile_grid_size)
+ return clahe.apply(np.array(img, dtype=np.uint8))
diff --git a/src/custom_mmpkg/custom_mmcv/model_zoo/deprecated.json b/src/custom_mmpkg/custom_mmcv/model_zoo/deprecated.json
new file mode 100644
index 0000000000000000000000000000000000000000..25cf6f28caecc22a77e3136fefa6b8dfc0e6cb5b
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/model_zoo/deprecated.json
@@ -0,0 +1,6 @@
+{
+ "resnet50_caffe": "detectron/resnet50_caffe",
+ "resnet50_caffe_bgr": "detectron2/resnet50_caffe_bgr",
+ "resnet101_caffe": "detectron/resnet101_caffe",
+ "resnet101_caffe_bgr": "detectron2/resnet101_caffe_bgr"
+}
diff --git a/src/custom_mmpkg/custom_mmcv/model_zoo/mmcls.json b/src/custom_mmpkg/custom_mmcv/model_zoo/mmcls.json
new file mode 100644
index 0000000000000000000000000000000000000000..bdb311d9fe6d9f317290feedc9e37236c6cf6e8f
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/model_zoo/mmcls.json
@@ -0,0 +1,31 @@
+{
+ "vgg11": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_batch256_imagenet_20210208-4271cd6c.pth",
+ "vgg13": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_batch256_imagenet_20210208-4d1d6080.pth",
+ "vgg16": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_batch256_imagenet_20210208-db26f1a5.pth",
+ "vgg19": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_batch256_imagenet_20210208-e6920e4a.pth",
+ "vgg11_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_bn_batch256_imagenet_20210207-f244902c.pth",
+ "vgg13_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_bn_batch256_imagenet_20210207-1a8b7864.pth",
+ "vgg16_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_bn_batch256_imagenet_20210208-7e55cd29.pth",
+ "vgg19_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_bn_batch256_imagenet_20210208-da620c4f.pth",
+ "resnet18": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_batch256_imagenet_20200708-34ab8f90.pth",
+ "resnet34": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_batch256_imagenet_20200708-32ffb4f7.pth",
+ "resnet50": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth",
+ "resnet101": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_batch256_imagenet_20200708-753f3608.pth",
+ "resnet152": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_batch256_imagenet_20200708-ec25b1f9.pth",
+ "resnet50_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_batch256_imagenet_20200708-1ad0ce94.pth",
+ "resnet101_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_batch256_imagenet_20200708-9cb302ef.pth",
+ "resnet152_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_batch256_imagenet_20200708-e79cb6a2.pth",
+ "resnext50_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext50_32x4d_b32x8_imagenet_20210429-56066e27.pth",
+ "resnext101_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x4d_b32x8_imagenet_20210506-e0fa3dd5.pth",
+ "resnext101_32x8d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x8d_b32x8_imagenet_20210506-23a247d5.pth",
+ "resnext152_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext152_32x4d_b32x8_imagenet_20210524-927787be.pth",
+ "se-resnet50": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet50_batch256_imagenet_20200804-ae206104.pth",
+ "se-resnet101": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet101_batch256_imagenet_20200804-ba5b51d4.pth",
+ "resnest50": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest50_imagenet_converted-1ebf0afe.pth",
+ "resnest101": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest101_imagenet_converted-032caa52.pth",
+ "resnest200": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest200_imagenet_converted-581a60f2.pth",
+ "resnest269": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest269_imagenet_converted-59930960.pth",
+ "shufflenet_v1": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v1/shufflenet_v1_batch1024_imagenet_20200804-5d6cec73.pth",
+ "shufflenet_v2": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v2/shufflenet_v2_batch1024_imagenet_20200812-5bf4721e.pth",
+ "mobilenet_v2": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth"
+}
diff --git a/src/custom_mmpkg/custom_mmcv/model_zoo/open_mmlab.json b/src/custom_mmpkg/custom_mmcv/model_zoo/open_mmlab.json
new file mode 100644
index 0000000000000000000000000000000000000000..8311db4feef92faa0841c697d75efbee8430c3a0
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/model_zoo/open_mmlab.json
@@ -0,0 +1,50 @@
+{
+ "vgg16_caffe": "https://download.openmmlab.com/pretrain/third_party/vgg16_caffe-292e1171.pth",
+ "detectron/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_caffe-788b5fa3.pth",
+ "detectron2/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_msra-5891d200.pth",
+ "detectron/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_caffe-3ad79236.pth",
+ "detectron2/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_msra-6cc46731.pth",
+ "detectron2/resnext101_32x8d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x8d-1516f1aa.pth",
+ "resnext50_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext50-32x4d-0ab1a123.pth",
+ "resnext101_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d-a5af3160.pth",
+ "resnext101_64x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth",
+ "contrib/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth",
+ "detectron/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn-9186a21c.pth",
+ "detectron/resnet101_gn": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn-cac0ab98.pth",
+ "jhu/resnet50_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_ws-15beedd8.pth",
+ "jhu/resnet101_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth",
+ "jhu/resnext50_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth",
+ "jhu/resnext101_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth",
+ "jhu/resnext50_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth",
+ "jhu/resnext101_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth",
+ "msra/hrnetv2_w18_small": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18_small-b5a04e21.pth",
+ "msra/hrnetv2_w18": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18-00eb2006.pth",
+ "msra/hrnetv2_w32": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth",
+ "msra/hrnetv2_w40": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w40-ed0b031c.pth",
+ "msra/hrnetv2_w48": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w48-d2186c55.pth",
+ "bninception_caffe": "https://download.openmmlab.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth",
+ "kin400/i3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth",
+ "kin400/nl3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth",
+ "res2net101_v1d_26w_4s": "https://download.openmmlab.com/pretrain/third_party/res2net101_v1d_26w_4s_mmdetv2-f0a600f9.pth",
+ "regnetx_400mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_400mf-a5b10d96.pth",
+ "regnetx_800mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_800mf-1f4be4c7.pth",
+ "regnetx_1.6gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_1.6gf-5791c176.pth",
+ "regnetx_3.2gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_3.2gf-c2599b0f.pth",
+ "regnetx_4.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_4.0gf-a88f671e.pth",
+ "regnetx_6.4gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_6.4gf-006af45d.pth",
+ "regnetx_8.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_8.0gf-3c68abe7.pth",
+ "regnetx_12gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_12gf-4c2a3350.pth",
+ "resnet18_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet18_v1c-b5776b93.pth",
+ "resnet50_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet50_v1c-2cccc1ad.pth",
+ "resnet101_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet101_v1c-e67eebb6.pth",
+ "mmedit/vgg16": "https://download.openmmlab.com/mmediting/third_party/vgg_state_dict.pth",
+ "mmedit/res34_en_nomixup": "https://download.openmmlab.com/mmediting/third_party/model_best_resnet34_En_nomixup.pth",
+ "mmedit/mobilenet_v2": "https://download.openmmlab.com/mmediting/third_party/mobilenet_v2.pth",
+ "contrib/mobilenet_v3_large": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_large-bc2c3fd3.pth",
+ "contrib/mobilenet_v3_small": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_small-47085aa1.pth",
+ "resnest50": "https://download.openmmlab.com/pretrain/third_party/resnest50_d2-7497a55b.pth",
+ "resnest101": "https://download.openmmlab.com/pretrain/third_party/resnest101_d2-f3b931b2.pth",
+ "resnest200": "https://download.openmmlab.com/pretrain/third_party/resnest200_d2-ca88e41f.pth",
+ "darknet53": "https://download.openmmlab.com/pretrain/third_party/darknet53-a628ea1b.pth",
+ "mmdet/mobilenet_v2": "https://download.openmmlab.com/mmdetection/v2.0/third_party/mobilenet_v2_batch256_imagenet-ff34753d.pth"
+}
diff --git a/src/custom_mmpkg/custom_mmcv/ops/__init__.py b/src/custom_mmpkg/custom_mmcv/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..999e090a458ee148ceca0649f1e3806a40e909bd
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/__init__.py
@@ -0,0 +1,81 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .assign_score_withk import assign_score_withk
+from .ball_query import ball_query
+from .bbox import bbox_overlaps
+from .border_align import BorderAlign, border_align
+from .box_iou_rotated import box_iou_rotated
+from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
+from .cc_attention import CrissCrossAttention
+from .contour_expand import contour_expand
+from .corner_pool import CornerPool
+from .correlation import Correlation
+from .deform_conv import DeformConv2d, DeformConv2dPack, deform_conv2d
+from .deform_roi_pool import (DeformRoIPool, DeformRoIPoolPack,
+ ModulatedDeformRoIPoolPack, deform_roi_pool)
+from .deprecated_wrappers import Conv2d_deprecated as Conv2d
+from .deprecated_wrappers import ConvTranspose2d_deprecated as ConvTranspose2d
+from .deprecated_wrappers import Linear_deprecated as Linear
+from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d
+from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
+ sigmoid_focal_loss, softmax_focal_loss)
+from .furthest_point_sample import (furthest_point_sample,
+ furthest_point_sample_with_dist)
+from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu
+from .gather_points import gather_points
+from .group_points import GroupAll, QueryAndGroup, grouping_operation
+from .info import (get_compiler_version, get_compiling_cuda_version,
+ get_onnxruntime_op_path)
+from .iou3d import boxes_iou_bev, nms_bev, nms_normal_bev
+from .knn import knn
+from .masked_conv import MaskedConv2d, masked_conv2d
+from .modulated_deform_conv import (ModulatedDeformConv2d,
+ ModulatedDeformConv2dPack,
+ modulated_deform_conv2d)
+from .multi_scale_deform_attn import MultiScaleDeformableAttention
+from .nms import batched_nms, nms, nms_match, nms_rotated, soft_nms
+from .pixel_group import pixel_group
+from .point_sample import (SimpleRoIAlign, point_sample,
+ rel_roi_point_to_rel_img_point)
+from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu,
+ points_in_boxes_part)
+from .points_sampler import PointsSampler
+from .psa_mask import PSAMask
+from .roi_align import RoIAlign, roi_align
+from .roi_align_rotated import RoIAlignRotated, roi_align_rotated
+from .roi_pool import RoIPool, roi_pool
+from .roiaware_pool3d import RoIAwarePool3d
+from .roipoint_pool3d import RoIPointPool3d
+from .saconv import SAConv2d
+from .scatter_points import DynamicScatter, dynamic_scatter
+from .sync_bn import SyncBatchNorm
+from .three_interpolate import three_interpolate
+from .three_nn import three_nn
+from .tin_shift import TINShift, tin_shift
+from .upfirdn2d import upfirdn2d
+from .voxelize import Voxelization, voxelization
+
+__all__ = [
+ 'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe',
+ 'carafe_naive', 'CornerPool', 'DeformConv2d', 'DeformConv2dPack',
+ 'deform_conv2d', 'DeformRoIPool', 'DeformRoIPoolPack',
+ 'ModulatedDeformRoIPoolPack', 'deform_roi_pool', 'SigmoidFocalLoss',
+ 'SoftmaxFocalLoss', 'sigmoid_focal_loss', 'softmax_focal_loss',
+ 'get_compiler_version', 'get_compiling_cuda_version',
+ 'get_onnxruntime_op_path', 'MaskedConv2d', 'masked_conv2d',
+ 'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack',
+ 'modulated_deform_conv2d', 'batched_nms', 'nms', 'soft_nms', 'nms_match',
+ 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
+ 'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
+ 'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
+ 'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk',
+ 'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query',
+ 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
+ 'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'QueryAndGroup',
+ 'GroupAll', 'grouping_operation', 'contour_expand', 'three_nn',
+ 'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign',
+ 'border_align', 'gather_points', 'furthest_point_sample',
+ 'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
+ 'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization',
+ 'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d',
+ 'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all'
+]
diff --git a/src/custom_mmpkg/custom_mmcv/ops/assign_score_withk.py b/src/custom_mmpkg/custom_mmcv/ops/assign_score_withk.py
new file mode 100644
index 0000000000000000000000000000000000000000..4906adaa2cffd1b46912fbe7d4f87ef2f9fa0012
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/assign_score_withk.py
@@ -0,0 +1,123 @@
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['assign_score_withk_forward', 'assign_score_withk_backward'])
+
+
+class AssignScoreWithK(Function):
+ r"""Perform weighted sum to generate output features according to scores.
+ Modified from `PAConv `_.
+
+ This is a memory-efficient CUDA implementation of assign_scores operation,
+ which first transform all point features with weight bank, then assemble
+ neighbor features with ``knn_idx`` and perform weighted sum of ``scores``.
+
+ See the `paper `_ appendix Sec. D for
+ more detailed descriptions.
+
+ Note:
+ This implementation assumes using ``neighbor`` kernel input, which is
+ (point_features - center_features, point_features).
+ See https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/
+ pointnet2/paconv.py#L128 for more details.
+ """
+
+ @staticmethod
+ def forward(ctx,
+ scores,
+ point_features,
+ center_features,
+ knn_idx,
+ aggregate='sum'):
+ """
+ Args:
+ scores (torch.Tensor): (B, npoint, K, M), predicted scores to
+ aggregate weight matrices in the weight bank.
+ ``npoint`` is the number of sampled centers.
+ ``K`` is the number of queried neighbors.
+ ``M`` is the number of weight matrices in the weight bank.
+ point_features (torch.Tensor): (B, N, M, out_dim)
+ Pre-computed point features to be aggregated.
+ center_features (torch.Tensor): (B, N, M, out_dim)
+ Pre-computed center features to be aggregated.
+ knn_idx (torch.Tensor): (B, npoint, K), index of sampled kNN.
+ We assume the first idx in each row is the idx of the center.
+ aggregate (str, optional): Aggregation method.
+ Can be 'sum', 'avg' or 'max'. Defaults: 'sum'.
+
+ Returns:
+ torch.Tensor: (B, out_dim, npoint, K), the aggregated features.
+ """
+ agg = {'sum': 0, 'avg': 1, 'max': 2}
+
+ B, N, M, out_dim = point_features.size()
+ _, npoint, K, _ = scores.size()
+
+ output = point_features.new_zeros((B, out_dim, npoint, K))
+ ext_module.assign_score_withk_forward(
+ point_features.contiguous(),
+ center_features.contiguous(),
+ scores.contiguous(),
+ knn_idx.contiguous(),
+ output,
+ B=B,
+ N0=N,
+ N1=npoint,
+ M=M,
+ K=K,
+ O=out_dim,
+ aggregate=agg[aggregate])
+
+ ctx.save_for_backward(output, point_features, center_features, scores,
+ knn_idx)
+ ctx.agg = agg[aggregate]
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ """
+ Args:
+ grad_out (torch.Tensor): (B, out_dim, npoint, K)
+
+ Returns:
+ grad_scores (torch.Tensor): (B, npoint, K, M)
+ grad_point_features (torch.Tensor): (B, N, M, out_dim)
+ grad_center_features (torch.Tensor): (B, N, M, out_dim)
+ """
+ _, point_features, center_features, scores, knn_idx = ctx.saved_tensors
+
+ agg = ctx.agg
+
+ B, N, M, out_dim = point_features.size()
+ _, npoint, K, _ = scores.size()
+
+ grad_point_features = point_features.new_zeros(point_features.shape)
+ grad_center_features = center_features.new_zeros(center_features.shape)
+ grad_scores = scores.new_zeros(scores.shape)
+
+ ext_module.assign_score_withk_backward(
+ grad_out.contiguous(),
+ point_features.contiguous(),
+ center_features.contiguous(),
+ scores.contiguous(),
+ knn_idx.contiguous(),
+ grad_point_features,
+ grad_center_features,
+ grad_scores,
+ B=B,
+ N0=N,
+ N1=npoint,
+ M=M,
+ K=K,
+ O=out_dim,
+ aggregate=agg)
+
+ return grad_scores, grad_point_features, \
+ grad_center_features, None, None
+
+
+assign_score_withk = AssignScoreWithK.apply
diff --git a/src/custom_mmpkg/custom_mmcv/ops/ball_query.py b/src/custom_mmpkg/custom_mmcv/ops/ball_query.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0466847c6e5c1239e359a0397568413ebc1504a
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/ball_query.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['ball_query_forward'])
+
+
+class BallQuery(Function):
+ """Find nearby points in spherical space."""
+
+ @staticmethod
+ def forward(ctx, min_radius: float, max_radius: float, sample_num: int,
+ xyz: torch.Tensor, center_xyz: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ min_radius (float): minimum radius of the balls.
+ max_radius (float): maximum radius of the balls.
+ sample_num (int): maximum number of features in the balls.
+ xyz (Tensor): (B, N, 3) xyz coordinates of the features.
+ center_xyz (Tensor): (B, npoint, 3) centers of the ball query.
+
+ Returns:
+ Tensor: (B, npoint, nsample) tensor with the indices of
+ the features that form the query balls.
+ """
+ assert center_xyz.is_contiguous()
+ assert xyz.is_contiguous()
+ assert min_radius < max_radius
+
+ B, N, _ = xyz.size()
+ npoint = center_xyz.size(1)
+ idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int)
+
+ ext_module.ball_query_forward(
+ center_xyz,
+ xyz,
+ idx,
+ b=B,
+ n=N,
+ m=npoint,
+ min_radius=min_radius,
+ max_radius=max_radius,
+ nsample=sample_num)
+ if torch.__version__ != 'parrots':
+ ctx.mark_non_differentiable(idx)
+ return idx
+
+ @staticmethod
+ def backward(ctx, a=None):
+ return None, None, None, None
+
+
+ball_query = BallQuery.apply
diff --git a/src/custom_mmpkg/custom_mmcv/ops/bbox.py b/src/custom_mmpkg/custom_mmcv/ops/bbox.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c4d58b6c91f652933974f519acd3403a833e906
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/bbox.py
@@ -0,0 +1,72 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['bbox_overlaps'])
+
+
+def bbox_overlaps(bboxes1, bboxes2, mode='iou', aligned=False, offset=0):
+ """Calculate overlap between two set of bboxes.
+
+ If ``aligned`` is ``False``, then calculate the ious between each bbox
+ of bboxes1 and bboxes2, otherwise the ious between each aligned pair of
+ bboxes1 and bboxes2.
+
+ Args:
+ bboxes1 (Tensor): shape (m, 4) in format or empty.
+ bboxes2 (Tensor): shape (n, 4) in format or empty.
+ If aligned is ``True``, then m and n must be equal.
+ mode (str): "iou" (intersection over union) or iof (intersection over
+ foreground).
+
+ Returns:
+ ious(Tensor): shape (m, n) if aligned == False else shape (m, 1)
+
+ Example:
+ >>> bboxes1 = torch.FloatTensor([
+ >>> [0, 0, 10, 10],
+ >>> [10, 10, 20, 20],
+ >>> [32, 32, 38, 42],
+ >>> ])
+ >>> bboxes2 = torch.FloatTensor([
+ >>> [0, 0, 10, 20],
+ >>> [0, 10, 10, 19],
+ >>> [10, 10, 20, 20],
+ >>> ])
+ >>> bbox_overlaps(bboxes1, bboxes2)
+ tensor([[0.5000, 0.0000, 0.0000],
+ [0.0000, 0.0000, 1.0000],
+ [0.0000, 0.0000, 0.0000]])
+
+ Example:
+ >>> empty = torch.FloatTensor([])
+ >>> nonempty = torch.FloatTensor([
+ >>> [0, 0, 10, 9],
+ >>> ])
+ >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1)
+ >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0)
+ >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0)
+ """
+
+ mode_dict = {'iou': 0, 'iof': 1}
+ assert mode in mode_dict.keys()
+ mode_flag = mode_dict[mode]
+ # Either the boxes are empty or the length of boxes' last dimension is 4
+ assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0)
+ assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0)
+ assert offset == 1 or offset == 0
+
+ rows = bboxes1.size(0)
+ cols = bboxes2.size(0)
+ if aligned:
+ assert rows == cols
+
+ if rows * cols == 0:
+ return bboxes1.new(rows, 1) if aligned else bboxes1.new(rows, cols)
+
+ if aligned:
+ ious = bboxes1.new_zeros(rows)
+ else:
+ ious = bboxes1.new_zeros((rows, cols))
+ ext_module.bbox_overlaps(
+ bboxes1, bboxes2, ious, mode=mode_flag, aligned=aligned, offset=offset)
+ return ious
diff --git a/src/custom_mmpkg/custom_mmcv/ops/border_align.py b/src/custom_mmpkg/custom_mmcv/ops/border_align.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff305be328e9b0a15e1bbb5e6b41beb940f55c81
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/border_align.py
@@ -0,0 +1,109 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# modified from
+# https://github.com/Megvii-BaseDetection/cvpods/blob/master/cvpods/layers/border_align.py
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['border_align_forward', 'border_align_backward'])
+
+
+class BorderAlignFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input, boxes, pool_size):
+ return g.op(
+ 'mmcv::MMCVBorderAlign', input, boxes, pool_size_i=pool_size)
+
+ @staticmethod
+ def forward(ctx, input, boxes, pool_size):
+ ctx.pool_size = pool_size
+ ctx.input_shape = input.size()
+
+ assert boxes.ndim == 3, 'boxes must be with shape [B, H*W, 4]'
+ assert boxes.size(2) == 4, \
+ 'the last dimension of boxes must be (x1, y1, x2, y2)'
+ assert input.size(1) % 4 == 0, \
+ 'the channel for input feature must be divisible by factor 4'
+
+ # [B, C//4, H*W, 4]
+ output_shape = (input.size(0), input.size(1) // 4, boxes.size(1), 4)
+ output = input.new_zeros(output_shape)
+ # `argmax_idx` only used for backward
+ argmax_idx = input.new_zeros(output_shape).to(torch.int)
+
+ ext_module.border_align_forward(
+ input, boxes, output, argmax_idx, pool_size=ctx.pool_size)
+
+ ctx.save_for_backward(boxes, argmax_idx)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ boxes, argmax_idx = ctx.saved_tensors
+ grad_input = grad_output.new_zeros(ctx.input_shape)
+ # complex head architecture may cause grad_output uncontiguous
+ grad_output = grad_output.contiguous()
+ ext_module.border_align_backward(
+ grad_output,
+ boxes,
+ argmax_idx,
+ grad_input,
+ pool_size=ctx.pool_size)
+ return grad_input, None, None
+
+
+border_align = BorderAlignFunction.apply
+
+
+class BorderAlign(nn.Module):
+ r"""Border align pooling layer.
+
+ Applies border_align over the input feature based on predicted bboxes.
+ The details were described in the paper
+ `BorderDet: Border Feature for Dense Object Detection
+ `_.
+
+ For each border line (e.g. top, left, bottom or right) of each box,
+ border_align does the following:
+ 1. uniformly samples `pool_size`+1 positions on this line, involving \
+ the start and end points.
+ 2. the corresponding features on these points are computed by \
+ bilinear interpolation.
+ 3. max pooling over all the `pool_size`+1 positions are used for \
+ computing pooled feature.
+
+ Args:
+ pool_size (int): number of positions sampled over the boxes' borders
+ (e.g. top, bottom, left, right).
+
+ """
+
+ def __init__(self, pool_size):
+ super(BorderAlign, self).__init__()
+ self.pool_size = pool_size
+
+ def forward(self, input, boxes):
+ """
+ Args:
+ input: Features with shape [N,4C,H,W]. Channels ranged in [0,C),
+ [C,2C), [2C,3C), [3C,4C) represent the top, left, bottom,
+ right features respectively.
+ boxes: Boxes with shape [N,H*W,4]. Coordinate format (x1,y1,x2,y2).
+
+ Returns:
+ Tensor: Pooled features with shape [N,C,H*W,4]. The order is
+ (top,left,bottom,right) for the last dimension.
+ """
+ return border_align(input, boxes, self.pool_size)
+
+ def __repr__(self):
+ s = self.__class__.__name__
+ s += f'(pool_size={self.pool_size})'
+ return s
diff --git a/src/custom_mmpkg/custom_mmcv/ops/box_iou_rotated.py b/src/custom_mmpkg/custom_mmcv/ops/box_iou_rotated.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d78015e9c2a9e7a52859b4e18f84a9aa63481a0
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/box_iou_rotated.py
@@ -0,0 +1,45 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['box_iou_rotated'])
+
+
+def box_iou_rotated(bboxes1, bboxes2, mode='iou', aligned=False):
+ """Return intersection-over-union (Jaccard index) of boxes.
+
+ Both sets of boxes are expected to be in
+ (x_center, y_center, width, height, angle) format.
+
+ If ``aligned`` is ``False``, then calculate the ious between each bbox
+ of bboxes1 and bboxes2, otherwise the ious between each aligned pair of
+ bboxes1 and bboxes2.
+
+ Arguments:
+ boxes1 (Tensor): rotated bboxes 1. \
+ It has shape (N, 5), indicating (x, y, w, h, theta) for each row.
+ Note that theta is in radian.
+ boxes2 (Tensor): rotated bboxes 2. \
+ It has shape (M, 5), indicating (x, y, w, h, theta) for each row.
+ Note that theta is in radian.
+ mode (str): "iou" (intersection over union) or iof (intersection over
+ foreground).
+
+ Returns:
+ ious(Tensor): shape (N, M) if aligned == False else shape (N,)
+ """
+ assert mode in ['iou', 'iof']
+ mode_dict = {'iou': 0, 'iof': 1}
+ mode_flag = mode_dict[mode]
+ rows = bboxes1.size(0)
+ cols = bboxes2.size(0)
+ if aligned:
+ ious = bboxes1.new_zeros(rows)
+ else:
+ ious = bboxes1.new_zeros((rows * cols))
+ bboxes1 = bboxes1.contiguous()
+ bboxes2 = bboxes2.contiguous()
+ ext_module.box_iou_rotated(
+ bboxes1, bboxes2, ious, mode_flag=mode_flag, aligned=aligned)
+ if not aligned:
+ ious = ious.view(rows, cols)
+ return ious
diff --git a/src/custom_mmpkg/custom_mmcv/ops/carafe.py b/src/custom_mmpkg/custom_mmcv/ops/carafe.py
new file mode 100644
index 0000000000000000000000000000000000000000..5154cb3abfccfbbe0a1b2daa67018dbf80aaf6d2
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/carafe.py
@@ -0,0 +1,287 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.nn.modules.module import Module
+
+from ..cnn import UPSAMPLE_LAYERS, normal_init, xavier_init
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+ 'carafe_naive_forward', 'carafe_naive_backward', 'carafe_forward',
+ 'carafe_backward'
+])
+
+
+class CARAFENaiveFunction(Function):
+
+ @staticmethod
+ def symbolic(g, features, masks, kernel_size, group_size, scale_factor):
+ return g.op(
+ 'mmcv::MMCVCARAFENaive',
+ features,
+ masks,
+ kernel_size_i=kernel_size,
+ group_size_i=group_size,
+ scale_factor_f=scale_factor)
+
+ @staticmethod
+ def forward(ctx, features, masks, kernel_size, group_size, scale_factor):
+ assert scale_factor >= 1
+ assert masks.size(1) == kernel_size * kernel_size * group_size
+ assert masks.size(-1) == features.size(-1) * scale_factor
+ assert masks.size(-2) == features.size(-2) * scale_factor
+ assert features.size(1) % group_size == 0
+ assert (kernel_size - 1) % 2 == 0 and kernel_size >= 1
+ ctx.kernel_size = kernel_size
+ ctx.group_size = group_size
+ ctx.scale_factor = scale_factor
+ ctx.feature_size = features.size()
+ ctx.mask_size = masks.size()
+
+ n, c, h, w = features.size()
+ output = features.new_zeros((n, c, h * scale_factor, w * scale_factor))
+ ext_module.carafe_naive_forward(
+ features,
+ masks,
+ output,
+ kernel_size=kernel_size,
+ group_size=group_size,
+ scale_factor=scale_factor)
+
+ if features.requires_grad or masks.requires_grad:
+ ctx.save_for_backward(features, masks)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ assert grad_output.is_cuda
+
+ features, masks = ctx.saved_tensors
+ kernel_size = ctx.kernel_size
+ group_size = ctx.group_size
+ scale_factor = ctx.scale_factor
+
+ grad_input = torch.zeros_like(features)
+ grad_masks = torch.zeros_like(masks)
+ ext_module.carafe_naive_backward(
+ grad_output.contiguous(),
+ features,
+ masks,
+ grad_input,
+ grad_masks,
+ kernel_size=kernel_size,
+ group_size=group_size,
+ scale_factor=scale_factor)
+
+ return grad_input, grad_masks, None, None, None
+
+
+carafe_naive = CARAFENaiveFunction.apply
+
+
+class CARAFENaive(Module):
+
+ def __init__(self, kernel_size, group_size, scale_factor):
+ super(CARAFENaive, self).__init__()
+
+ assert isinstance(kernel_size, int) and isinstance(
+ group_size, int) and isinstance(scale_factor, int)
+ self.kernel_size = kernel_size
+ self.group_size = group_size
+ self.scale_factor = scale_factor
+
+ def forward(self, features, masks):
+ return carafe_naive(features, masks, self.kernel_size, self.group_size,
+ self.scale_factor)
+
+
+class CARAFEFunction(Function):
+
+ @staticmethod
+ def symbolic(g, features, masks, kernel_size, group_size, scale_factor):
+ return g.op(
+ 'mmcv::MMCVCARAFE',
+ features,
+ masks,
+ kernel_size_i=kernel_size,
+ group_size_i=group_size,
+ scale_factor_f=scale_factor)
+
+ @staticmethod
+ def forward(ctx, features, masks, kernel_size, group_size, scale_factor):
+ assert scale_factor >= 1
+ assert masks.size(1) == kernel_size * kernel_size * group_size
+ assert masks.size(-1) == features.size(-1) * scale_factor
+ assert masks.size(-2) == features.size(-2) * scale_factor
+ assert features.size(1) % group_size == 0
+ assert (kernel_size - 1) % 2 == 0 and kernel_size >= 1
+ ctx.kernel_size = kernel_size
+ ctx.group_size = group_size
+ ctx.scale_factor = scale_factor
+ ctx.feature_size = features.size()
+ ctx.mask_size = masks.size()
+
+ n, c, h, w = features.size()
+ output = features.new_zeros((n, c, h * scale_factor, w * scale_factor))
+ routput = features.new_zeros(output.size(), requires_grad=False)
+ rfeatures = features.new_zeros(features.size(), requires_grad=False)
+ rmasks = masks.new_zeros(masks.size(), requires_grad=False)
+ ext_module.carafe_forward(
+ features,
+ masks,
+ rfeatures,
+ routput,
+ rmasks,
+ output,
+ kernel_size=kernel_size,
+ group_size=group_size,
+ scale_factor=scale_factor)
+
+ if features.requires_grad or masks.requires_grad:
+ ctx.save_for_backward(features, masks, rfeatures)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ assert grad_output.is_cuda
+
+ features, masks, rfeatures = ctx.saved_tensors
+ kernel_size = ctx.kernel_size
+ group_size = ctx.group_size
+ scale_factor = ctx.scale_factor
+
+ rgrad_output = torch.zeros_like(grad_output, requires_grad=False)
+ rgrad_input_hs = torch.zeros_like(grad_output, requires_grad=False)
+ rgrad_input = torch.zeros_like(features, requires_grad=False)
+ rgrad_masks = torch.zeros_like(masks, requires_grad=False)
+ grad_input = torch.zeros_like(features, requires_grad=False)
+ grad_masks = torch.zeros_like(masks, requires_grad=False)
+ ext_module.carafe_backward(
+ grad_output.contiguous(),
+ rfeatures,
+ masks,
+ rgrad_output,
+ rgrad_input_hs,
+ rgrad_input,
+ rgrad_masks,
+ grad_input,
+ grad_masks,
+ kernel_size=kernel_size,
+ group_size=group_size,
+ scale_factor=scale_factor)
+ return grad_input, grad_masks, None, None, None
+
+
+carafe = CARAFEFunction.apply
+
+
+class CARAFE(Module):
+ """ CARAFE: Content-Aware ReAssembly of FEatures
+
+ Please refer to https://arxiv.org/abs/1905.02188 for more details.
+
+ Args:
+ kernel_size (int): reassemble kernel size
+ group_size (int): reassemble group size
+ scale_factor (int): upsample ratio
+
+ Returns:
+ upsampled feature map
+ """
+
+ def __init__(self, kernel_size, group_size, scale_factor):
+ super(CARAFE, self).__init__()
+
+ assert isinstance(kernel_size, int) and isinstance(
+ group_size, int) and isinstance(scale_factor, int)
+ self.kernel_size = kernel_size
+ self.group_size = group_size
+ self.scale_factor = scale_factor
+
+ def forward(self, features, masks):
+ return carafe(features, masks, self.kernel_size, self.group_size,
+ self.scale_factor)
+
+
+@UPSAMPLE_LAYERS.register_module(name='carafe')
+class CARAFEPack(nn.Module):
+ """A unified package of CARAFE upsampler that contains: 1) channel
+ compressor 2) content encoder 3) CARAFE op.
+
+ Official implementation of ICCV 2019 paper
+ CARAFE: Content-Aware ReAssembly of FEatures
+ Please refer to https://arxiv.org/abs/1905.02188 for more details.
+
+ Args:
+ channels (int): input feature channels
+ scale_factor (int): upsample ratio
+ up_kernel (int): kernel size of CARAFE op
+ up_group (int): group size of CARAFE op
+ encoder_kernel (int): kernel size of content encoder
+ encoder_dilation (int): dilation of content encoder
+ compressed_channels (int): output channels of channels compressor
+
+ Returns:
+ upsampled feature map
+ """
+
+ def __init__(self,
+ channels,
+ scale_factor,
+ up_kernel=5,
+ up_group=1,
+ encoder_kernel=3,
+ encoder_dilation=1,
+ compressed_channels=64):
+ super(CARAFEPack, self).__init__()
+ self.channels = channels
+ self.scale_factor = scale_factor
+ self.up_kernel = up_kernel
+ self.up_group = up_group
+ self.encoder_kernel = encoder_kernel
+ self.encoder_dilation = encoder_dilation
+ self.compressed_channels = compressed_channels
+ self.channel_compressor = nn.Conv2d(channels, self.compressed_channels,
+ 1)
+ self.content_encoder = nn.Conv2d(
+ self.compressed_channels,
+ self.up_kernel * self.up_kernel * self.up_group *
+ self.scale_factor * self.scale_factor,
+ self.encoder_kernel,
+ padding=int((self.encoder_kernel - 1) * self.encoder_dilation / 2),
+ dilation=self.encoder_dilation,
+ groups=1)
+ self.init_weights()
+
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ xavier_init(m, distribution='uniform')
+ normal_init(self.content_encoder, std=0.001)
+
+ def kernel_normalizer(self, mask):
+ mask = F.pixel_shuffle(mask, self.scale_factor)
+ n, mask_c, h, w = mask.size()
+ # use float division explicitly,
+ # to void inconsistency while exporting to onnx
+ mask_channel = int(mask_c / float(self.up_kernel**2))
+ mask = mask.view(n, mask_channel, -1, h, w)
+
+ mask = F.softmax(mask, dim=2, dtype=mask.dtype)
+ mask = mask.view(n, mask_c, h, w).contiguous()
+
+ return mask
+
+ def feature_reassemble(self, x, mask):
+ x = carafe(x, mask, self.up_kernel, self.up_group, self.scale_factor)
+ return x
+
+ def forward(self, x):
+ compressed_x = self.channel_compressor(x)
+ mask = self.content_encoder(compressed_x)
+ mask = self.kernel_normalizer(mask)
+
+ x = self.feature_reassemble(x, mask)
+ return x
diff --git a/src/custom_mmpkg/custom_mmcv/ops/cc_attention.py b/src/custom_mmpkg/custom_mmcv/ops/cc_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6868974cae6e5a7b9a6841845f9fca909a27155
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/cc_attention.py
@@ -0,0 +1,83 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from custom_mmpkg.custom_mmcv.cnn import PLUGIN_LAYERS, Scale
+
+
+def NEG_INF_DIAG(n, device):
+ """Returns a diagonal matrix of size [n, n].
+
+ The diagonal are all "-inf". This is for avoiding calculating the
+ overlapped element in the Criss-Cross twice.
+ """
+ return torch.diag(torch.tensor(float('-inf')).to(device).repeat(n), 0)
+
+
+@PLUGIN_LAYERS.register_module()
+class CrissCrossAttention(nn.Module):
+ """Criss-Cross Attention Module.
+
+ .. note::
+ Before v1.3.13, we use a CUDA op. Since v1.3.13, we switch
+ to a pure PyTorch and equivalent implementation. For more
+ details, please refer to https://github.com/open-mmlab/mmcv/pull/1201.
+
+ Speed comparison for one forward pass
+
+ - Input size: [2,512,97,97]
+ - Device: 1 NVIDIA GeForce RTX 2080 Ti
+
+ +-----------------------+---------------+------------+---------------+
+ | |PyTorch version|CUDA version|Relative speed |
+ +=======================+===============+============+===============+
+ |with torch.no_grad() |0.00554402 s |0.0299619 s |5.4x |
+ +-----------------------+---------------+------------+---------------+
+ |no with torch.no_grad()|0.00562803 s |0.0301349 s |5.4x |
+ +-----------------------+---------------+------------+---------------+
+
+ Args:
+ in_channels (int): Channels of the input feature map.
+ """
+
+ def __init__(self, in_channels):
+ super().__init__()
+ self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
+ self.key_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
+ self.value_conv = nn.Conv2d(in_channels, in_channels, 1)
+ self.gamma = Scale(0.)
+ self.in_channels = in_channels
+
+ def forward(self, x):
+ """forward function of Criss-Cross Attention.
+
+ Args:
+ x (Tensor): Input feature. \
+ shape (batch_size, in_channels, height, width)
+ Returns:
+ Tensor: Output of the layer, with shape of \
+ (batch_size, in_channels, height, width)
+ """
+ B, C, H, W = x.size()
+ query = self.query_conv(x)
+ key = self.key_conv(x)
+ value = self.value_conv(x)
+ energy_H = torch.einsum('bchw,bciw->bwhi', query, key) + NEG_INF_DIAG(
+ H, query.device)
+ energy_H = energy_H.transpose(1, 2)
+ energy_W = torch.einsum('bchw,bchj->bhwj', query, key)
+ attn = F.softmax(
+ torch.cat([energy_H, energy_W], dim=-1), dim=-1) # [B,H,W,(H+W)]
+ out = torch.einsum('bciw,bhwi->bchw', value, attn[..., :H])
+ out += torch.einsum('bchj,bhwj->bchw', value, attn[..., H:])
+
+ out = self.gamma(out) + x
+ out = out.contiguous()
+
+ return out
+
+ def __repr__(self):
+ s = self.__class__.__name__
+ s += f'(in_channels={self.in_channels})'
+ return s
diff --git a/src/custom_mmpkg/custom_mmcv/ops/contour_expand.py b/src/custom_mmpkg/custom_mmcv/ops/contour_expand.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea1111e1768b5f27e118bf7dbc0d9c70a7afd6d7
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/contour_expand.py
@@ -0,0 +1,49 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['contour_expand'])
+
+
+def contour_expand(kernel_mask, internal_kernel_label, min_kernel_area,
+ kernel_num):
+ """Expand kernel contours so that foreground pixels are assigned into
+ instances.
+
+ Arguments:
+ kernel_mask (np.array or Tensor): The instance kernel mask with
+ size hxw.
+ internal_kernel_label (np.array or Tensor): The instance internal
+ kernel label with size hxw.
+ min_kernel_area (int): The minimum kernel area.
+ kernel_num (int): The instance kernel number.
+
+ Returns:
+ label (list): The instance index map with size hxw.
+ """
+ assert isinstance(kernel_mask, (torch.Tensor, np.ndarray))
+ assert isinstance(internal_kernel_label, (torch.Tensor, np.ndarray))
+ assert isinstance(min_kernel_area, int)
+ assert isinstance(kernel_num, int)
+
+ if isinstance(kernel_mask, np.ndarray):
+ kernel_mask = torch.from_numpy(kernel_mask)
+ if isinstance(internal_kernel_label, np.ndarray):
+ internal_kernel_label = torch.from_numpy(internal_kernel_label)
+
+ if torch.__version__ == 'parrots':
+ if kernel_mask.shape[0] == 0 or internal_kernel_label.shape[0] == 0:
+ label = []
+ else:
+ label = ext_module.contour_expand(
+ kernel_mask,
+ internal_kernel_label,
+ min_kernel_area=min_kernel_area,
+ kernel_num=kernel_num)
+ label = label.tolist()
+ else:
+ label = ext_module.contour_expand(kernel_mask, internal_kernel_label,
+ min_kernel_area, kernel_num)
+ return label
diff --git a/src/custom_mmpkg/custom_mmcv/ops/corner_pool.py b/src/custom_mmpkg/custom_mmcv/ops/corner_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..a33d798b43d405e4c86bee4cd6389be21ca9c637
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/corner_pool.py
@@ -0,0 +1,161 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+ 'top_pool_forward', 'top_pool_backward', 'bottom_pool_forward',
+ 'bottom_pool_backward', 'left_pool_forward', 'left_pool_backward',
+ 'right_pool_forward', 'right_pool_backward'
+])
+
+_mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3}
+
+
+class TopPoolFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input):
+ output = g.op(
+ 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['top']))
+ return output
+
+ @staticmethod
+ def forward(ctx, input):
+ output = ext_module.top_pool_forward(input)
+ ctx.save_for_backward(input)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, = ctx.saved_tensors
+ output = ext_module.top_pool_backward(input, grad_output)
+ return output
+
+
+class BottomPoolFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input):
+ output = g.op(
+ 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['bottom']))
+ return output
+
+ @staticmethod
+ def forward(ctx, input):
+ output = ext_module.bottom_pool_forward(input)
+ ctx.save_for_backward(input)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, = ctx.saved_tensors
+ output = ext_module.bottom_pool_backward(input, grad_output)
+ return output
+
+
+class LeftPoolFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input):
+ output = g.op(
+ 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['left']))
+ return output
+
+ @staticmethod
+ def forward(ctx, input):
+ output = ext_module.left_pool_forward(input)
+ ctx.save_for_backward(input)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, = ctx.saved_tensors
+ output = ext_module.left_pool_backward(input, grad_output)
+ return output
+
+
+class RightPoolFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input):
+ output = g.op(
+ 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['right']))
+ return output
+
+ @staticmethod
+ def forward(ctx, input):
+ output = ext_module.right_pool_forward(input)
+ ctx.save_for_backward(input)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, = ctx.saved_tensors
+ output = ext_module.right_pool_backward(input, grad_output)
+ return output
+
+
+class CornerPool(nn.Module):
+ """Corner Pooling.
+
+ Corner Pooling is a new type of pooling layer that helps a
+ convolutional network better localize corners of bounding boxes.
+
+ Please refer to https://arxiv.org/abs/1808.01244 for more details.
+ Code is modified from https://github.com/princeton-vl/CornerNet-Lite.
+
+ Args:
+ mode(str): Pooling orientation for the pooling layer
+
+ - 'bottom': Bottom Pooling
+ - 'left': Left Pooling
+ - 'right': Right Pooling
+ - 'top': Top Pooling
+
+ Returns:
+ Feature map after pooling.
+ """
+
+ pool_functions = {
+ 'bottom': BottomPoolFunction,
+ 'left': LeftPoolFunction,
+ 'right': RightPoolFunction,
+ 'top': TopPoolFunction,
+ }
+
+ cummax_dim_flip = {
+ 'bottom': (2, False),
+ 'left': (3, True),
+ 'right': (3, False),
+ 'top': (2, True),
+ }
+
+ def __init__(self, mode):
+ super(CornerPool, self).__init__()
+ assert mode in self.pool_functions
+ self.mode = mode
+ self.corner_pool = self.pool_functions[mode]
+
+ def forward(self, x):
+ if torch.__version__ != 'parrots' and torch.__version__ >= '1.5.0':
+ if torch.onnx.is_in_onnx_export():
+ assert torch.__version__ >= '1.7.0', \
+ 'When `cummax` serves as an intermediate component whose '\
+ 'outputs is used as inputs for another modules, it\'s '\
+ 'expected that pytorch version must be >= 1.7.0, '\
+ 'otherwise Error appears like: `RuntimeError: tuple '\
+ 'appears in op that does not forward tuples, unsupported '\
+ 'kind: prim::PythonOp`.'
+
+ dim, flip = self.cummax_dim_flip[self.mode]
+ if flip:
+ x = x.flip(dim)
+ pool_tensor, _ = torch.cummax(x, dim=dim)
+ if flip:
+ pool_tensor = pool_tensor.flip(dim)
+ return pool_tensor
+ else:
+ return self.corner_pool.apply(x)
diff --git a/src/custom_mmpkg/custom_mmcv/ops/correlation.py b/src/custom_mmpkg/custom_mmcv/ops/correlation.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d0b79c301b29915dfaf4d2b1846c59be73127d3
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/correlation.py
@@ -0,0 +1,196 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import Tensor, nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['correlation_forward', 'correlation_backward'])
+
+
+class CorrelationFunction(Function):
+
+ @staticmethod
+ def forward(ctx,
+ input1,
+ input2,
+ kernel_size=1,
+ max_displacement=1,
+ stride=1,
+ padding=1,
+ dilation=1,
+ dilation_patch=1):
+
+ ctx.save_for_backward(input1, input2)
+
+ kH, kW = ctx.kernel_size = _pair(kernel_size)
+ patch_size = max_displacement * 2 + 1
+ ctx.patch_size = patch_size
+ dH, dW = ctx.stride = _pair(stride)
+ padH, padW = ctx.padding = _pair(padding)
+ dilationH, dilationW = ctx.dilation = _pair(dilation)
+ dilation_patchH, dilation_patchW = ctx.dilation_patch = _pair(
+ dilation_patch)
+
+ output_size = CorrelationFunction._output_size(ctx, input1)
+
+ output = input1.new_zeros(output_size)
+
+ ext_module.correlation_forward(
+ input1,
+ input2,
+ output,
+ kH=kH,
+ kW=kW,
+ patchH=patch_size,
+ patchW=patch_size,
+ padH=padH,
+ padW=padW,
+ dilationH=dilationH,
+ dilationW=dilationW,
+ dilation_patchH=dilation_patchH,
+ dilation_patchW=dilation_patchW,
+ dH=dH,
+ dW=dW)
+
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ input1, input2 = ctx.saved_tensors
+
+ kH, kW = ctx.kernel_size
+ patch_size = ctx.patch_size
+ padH, padW = ctx.padding
+ dilationH, dilationW = ctx.dilation
+ dilation_patchH, dilation_patchW = ctx.dilation_patch
+ dH, dW = ctx.stride
+ grad_input1 = torch.zeros_like(input1)
+ grad_input2 = torch.zeros_like(input2)
+
+ ext_module.correlation_backward(
+ grad_output,
+ input1,
+ input2,
+ grad_input1,
+ grad_input2,
+ kH=kH,
+ kW=kW,
+ patchH=patch_size,
+ patchW=patch_size,
+ padH=padH,
+ padW=padW,
+ dilationH=dilationH,
+ dilationW=dilationW,
+ dilation_patchH=dilation_patchH,
+ dilation_patchW=dilation_patchW,
+ dH=dH,
+ dW=dW)
+ return grad_input1, grad_input2, None, None, None, None, None, None
+
+ @staticmethod
+ def _output_size(ctx, input1):
+ iH, iW = input1.size(2), input1.size(3)
+ batch_size = input1.size(0)
+ kH, kW = ctx.kernel_size
+ patch_size = ctx.patch_size
+ dH, dW = ctx.stride
+ padH, padW = ctx.padding
+ dilationH, dilationW = ctx.dilation
+ dilatedKH = (kH - 1) * dilationH + 1
+ dilatedKW = (kW - 1) * dilationW + 1
+
+ oH = int((iH + 2 * padH - dilatedKH) / dH + 1)
+ oW = int((iW + 2 * padW - dilatedKW) / dW + 1)
+
+ output_size = (batch_size, patch_size, patch_size, oH, oW)
+ return output_size
+
+
+class Correlation(nn.Module):
+ r"""Correlation operator
+
+ This correlation operator works for optical flow correlation computation.
+
+ There are two batched tensors with shape :math:`(N, C, H, W)`,
+ and the correlation output's shape is :math:`(N, max\_displacement \times
+ 2 + 1, max\_displacement * 2 + 1, H_{out}, W_{out})`
+
+ where
+
+ .. math::
+ H_{out} = \left\lfloor\frac{H_{in} + 2 \times padding -
+ dilation \times (kernel\_size - 1) - 1}
+ {stride} + 1\right\rfloor
+
+ .. math::
+ W_{out} = \left\lfloor\frac{W_{in} + 2 \times padding - dilation
+ \times (kernel\_size - 1) - 1}
+ {stride} + 1\right\rfloor
+
+ the correlation item :math:`(N_i, dy, dx)` is formed by taking the sliding
+ window convolution between input1 and shifted input2,
+
+ .. math::
+ Corr(N_i, dx, dy) =
+ \sum_{c=0}^{C-1}
+ input1(N_i, c) \star
+ \mathcal{S}(input2(N_i, c), dy, dx)
+
+ where :math:`\star` is the valid 2d sliding window convolution operator,
+ and :math:`\mathcal{S}` means shifting the input features (auto-complete
+ zero marginal), and :math:`dx, dy` are shifting distance, :math:`dx, dy \in
+ [-max\_displacement \times dilation\_patch, max\_displacement \times
+ dilation\_patch]`.
+
+ Args:
+ kernel_size (int): The size of sliding window i.e. local neighborhood
+ representing the center points and involved in correlation
+ computation. Defaults to 1.
+ max_displacement (int): The radius for computing correlation volume,
+ but the actual working space can be dilated by dilation_patch.
+ Defaults to 1.
+ stride (int): The stride of the sliding blocks in the input spatial
+ dimensions. Defaults to 1.
+ padding (int): Zero padding added to all four sides of the input1.
+ Defaults to 0.
+ dilation (int): The spacing of local neighborhood that will involved
+ in correlation. Defaults to 1.
+ dilation_patch (int): The spacing between position need to compute
+ correlation. Defaults to 1.
+ """
+
+ def __init__(self,
+ kernel_size: int = 1,
+ max_displacement: int = 1,
+ stride: int = 1,
+ padding: int = 0,
+ dilation: int = 1,
+ dilation_patch: int = 1) -> None:
+ super().__init__()
+ self.kernel_size = kernel_size
+ self.max_displacement = max_displacement
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.dilation_patch = dilation_patch
+
+ def forward(self, input1: Tensor, input2: Tensor) -> Tensor:
+ return CorrelationFunction.apply(input1, input2, self.kernel_size,
+ self.max_displacement, self.stride,
+ self.padding, self.dilation,
+ self.dilation_patch)
+
+ def __repr__(self) -> str:
+ s = self.__class__.__name__
+ s += f'(kernel_size={self.kernel_size}, '
+ s += f'max_displacement={self.max_displacement}, '
+ s += f'stride={self.stride}, '
+ s += f'padding={self.padding}, '
+ s += f'dilation={self.dilation}, '
+ s += f'dilation_patch={self.dilation_patch})'
+ return s
diff --git a/src/custom_mmpkg/custom_mmcv/ops/deform_conv.py b/src/custom_mmpkg/custom_mmcv/ops/deform_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8e80c5af51a525915875a1f9cb030e77d24f190
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/deform_conv.py
@@ -0,0 +1,405 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair, _single
+
+from custom_mmpkg.custom_mmcv.utils import deprecated_api_warning
+from ..cnn import CONV_LAYERS
+from ..utils import ext_loader, print_log
+
+ext_module = ext_loader.load_ext('_ext', [
+ 'deform_conv_forward', 'deform_conv_backward_input',
+ 'deform_conv_backward_parameters'
+])
+
+
+class DeformConv2dFunction(Function):
+
+ @staticmethod
+ def symbolic(g,
+ input,
+ offset,
+ weight,
+ stride,
+ padding,
+ dilation,
+ groups,
+ deform_groups,
+ bias=False,
+ im2col_step=32):
+ return g.op(
+ 'mmcv::MMCVDeformConv2d',
+ input,
+ offset,
+ weight,
+ stride_i=stride,
+ padding_i=padding,
+ dilation_i=dilation,
+ groups_i=groups,
+ deform_groups_i=deform_groups,
+ bias_i=bias,
+ im2col_step_i=im2col_step)
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ offset,
+ weight,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deform_groups=1,
+ bias=False,
+ im2col_step=32):
+ if input is not None and input.dim() != 4:
+ raise ValueError(
+ f'Expected 4D tensor as input, got {input.dim()}D tensor \
+ instead.')
+ assert bias is False, 'Only support bias is False.'
+ ctx.stride = _pair(stride)
+ ctx.padding = _pair(padding)
+ ctx.dilation = _pair(dilation)
+ ctx.groups = groups
+ ctx.deform_groups = deform_groups
+ ctx.im2col_step = im2col_step
+
+ # When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
+ # amp won't cast the type of model (float32), but "offset" is cast
+ # to float16 by nn.Conv2d automatically, leading to the type
+ # mismatch with input (when it is float32) or weight.
+ # The flag for whether to use fp16 or amp is the type of "offset",
+ # we cast weight and input to temporarily support fp16 and amp
+ # whatever the pytorch version is.
+ input = input.type_as(offset)
+ weight = weight.type_as(input)
+ ctx.save_for_backward(input, offset, weight)
+
+ output = input.new_empty(
+ DeformConv2dFunction._output_size(ctx, input, weight))
+
+ ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
+
+ cur_im2col_step = min(ctx.im2col_step, input.size(0))
+ assert (input.size(0) %
+ cur_im2col_step) == 0, 'im2col step must divide batchsize'
+ ext_module.deform_conv_forward(
+ input,
+ weight,
+ offset,
+ output,
+ ctx.bufs_[0],
+ ctx.bufs_[1],
+ kW=weight.size(3),
+ kH=weight.size(2),
+ dW=ctx.stride[1],
+ dH=ctx.stride[0],
+ padW=ctx.padding[1],
+ padH=ctx.padding[0],
+ dilationW=ctx.dilation[1],
+ dilationH=ctx.dilation[0],
+ group=ctx.groups,
+ deformable_group=ctx.deform_groups,
+ im2col_step=cur_im2col_step)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ input, offset, weight = ctx.saved_tensors
+
+ grad_input = grad_offset = grad_weight = None
+
+ cur_im2col_step = min(ctx.im2col_step, input.size(0))
+ assert (input.size(0) % cur_im2col_step
+ ) == 0, 'batch size must be divisible by im2col_step'
+
+ grad_output = grad_output.contiguous()
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+ grad_input = torch.zeros_like(input)
+ grad_offset = torch.zeros_like(offset)
+ ext_module.deform_conv_backward_input(
+ input,
+ offset,
+ grad_output,
+ grad_input,
+ grad_offset,
+ weight,
+ ctx.bufs_[0],
+ kW=weight.size(3),
+ kH=weight.size(2),
+ dW=ctx.stride[1],
+ dH=ctx.stride[0],
+ padW=ctx.padding[1],
+ padH=ctx.padding[0],
+ dilationW=ctx.dilation[1],
+ dilationH=ctx.dilation[0],
+ group=ctx.groups,
+ deformable_group=ctx.deform_groups,
+ im2col_step=cur_im2col_step)
+
+ if ctx.needs_input_grad[2]:
+ grad_weight = torch.zeros_like(weight)
+ ext_module.deform_conv_backward_parameters(
+ input,
+ offset,
+ grad_output,
+ grad_weight,
+ ctx.bufs_[0],
+ ctx.bufs_[1],
+ kW=weight.size(3),
+ kH=weight.size(2),
+ dW=ctx.stride[1],
+ dH=ctx.stride[0],
+ padW=ctx.padding[1],
+ padH=ctx.padding[0],
+ dilationW=ctx.dilation[1],
+ dilationH=ctx.dilation[0],
+ group=ctx.groups,
+ deformable_group=ctx.deform_groups,
+ scale=1,
+ im2col_step=cur_im2col_step)
+
+ return grad_input, grad_offset, grad_weight, \
+ None, None, None, None, None, None, None
+
+ @staticmethod
+ def _output_size(ctx, input, weight):
+ channels = weight.size(0)
+ output_size = (input.size(0), channels)
+ for d in range(input.dim() - 2):
+ in_size = input.size(d + 2)
+ pad = ctx.padding[d]
+ kernel = ctx.dilation[d] * (weight.size(d + 2) - 1) + 1
+ stride_ = ctx.stride[d]
+ output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
+ if not all(map(lambda s: s > 0, output_size)):
+ raise ValueError(
+ 'convolution input is too small (output would be ' +
+ 'x'.join(map(str, output_size)) + ')')
+ return output_size
+
+
+deform_conv2d = DeformConv2dFunction.apply
+
+
+class DeformConv2d(nn.Module):
+ r"""Deformable 2D convolution.
+
+ Applies a deformable 2D convolution over an input signal composed of
+ several input planes. DeformConv2d was described in the paper
+ `Deformable Convolutional Networks
+ `_
+
+ Note:
+ The argument ``im2col_step`` was added in version 1.3.17, which means
+ number of samples processed by the ``im2col_cuda_kernel`` per call.
+ It enables users to define ``batch_size`` and ``im2col_step`` more
+ flexibly and solved `issue mmcv#1440
+ `_.
+
+ Args:
+ in_channels (int): Number of channels in the input image.
+ out_channels (int): Number of channels produced by the convolution.
+ kernel_size(int, tuple): Size of the convolving kernel.
+ stride(int, tuple): Stride of the convolution. Default: 1.
+ padding (int or tuple): Zero-padding added to both sides of the input.
+ Default: 0.
+ dilation (int or tuple): Spacing between kernel elements. Default: 1.
+ groups (int): Number of blocked connections from input.
+ channels to output channels. Default: 1.
+ deform_groups (int): Number of deformable group partitions.
+ bias (bool): If True, adds a learnable bias to the output.
+ Default: False.
+ im2col_step (int): Number of samples processed by im2col_cuda_kernel
+ per call. It will work when ``batch_size`` > ``im2col_step``, but
+ ``batch_size`` must be divisible by ``im2col_step``. Default: 32.
+ `New in version 1.3.17.`
+ """
+
+ @deprecated_api_warning({'deformable_groups': 'deform_groups'},
+ cls_name='DeformConv2d')
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, ...]],
+ stride: Union[int, Tuple[int, ...]] = 1,
+ padding: Union[int, Tuple[int, ...]] = 0,
+ dilation: Union[int, Tuple[int, ...]] = 1,
+ groups: int = 1,
+ deform_groups: int = 1,
+ bias: bool = False,
+ im2col_step: int = 32) -> None:
+ super(DeformConv2d, self).__init__()
+
+ assert not bias, \
+ f'bias={bias} is not supported in DeformConv2d.'
+ assert in_channels % groups == 0, \
+ f'in_channels {in_channels} cannot be divisible by groups {groups}'
+ assert out_channels % groups == 0, \
+ f'out_channels {out_channels} cannot be divisible by groups \
+ {groups}'
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = _pair(stride)
+ self.padding = _pair(padding)
+ self.dilation = _pair(dilation)
+ self.groups = groups
+ self.deform_groups = deform_groups
+ self.im2col_step = im2col_step
+ # enable compatibility with nn.Conv2d
+ self.transposed = False
+ self.output_padding = _single(0)
+
+ # only weight, no bias
+ self.weight = nn.Parameter(
+ torch.Tensor(out_channels, in_channels // self.groups,
+ *self.kernel_size))
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ # switch the initialization of `self.weight` to the standard kaiming
+ # method described in `Delving deep into rectifiers: Surpassing
+ # human-level performance on ImageNet classification` - He, K. et al.
+ # (2015), using a uniform distribution
+ nn.init.kaiming_uniform_(self.weight, nonlinearity='relu')
+
+ def forward(self, x: Tensor, offset: Tensor) -> Tensor:
+ """Deformable Convolutional forward function.
+
+ Args:
+ x (Tensor): Input feature, shape (B, C_in, H_in, W_in)
+ offset (Tensor): Offset for deformable convolution, shape
+ (B, deform_groups*kernel_size[0]*kernel_size[1]*2,
+ H_out, W_out), H_out, W_out are equal to the output's.
+
+ An offset is like `[y0, x0, y1, x1, y2, x2, ..., y8, x8]`.
+ The spatial arrangement is like:
+
+ .. code:: text
+
+ (x0, y0) (x1, y1) (x2, y2)
+ (x3, y3) (x4, y4) (x5, y5)
+ (x6, y6) (x7, y7) (x8, y8)
+
+ Returns:
+ Tensor: Output of the layer.
+ """
+ # To fix an assert error in deform_conv_cuda.cpp:128
+ # input image is smaller than kernel
+ input_pad = (x.size(2) < self.kernel_size[0]) or (x.size(3) <
+ self.kernel_size[1])
+ if input_pad:
+ pad_h = max(self.kernel_size[0] - x.size(2), 0)
+ pad_w = max(self.kernel_size[1] - x.size(3), 0)
+ x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
+ offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0)
+ offset = offset.contiguous()
+ out = deform_conv2d(x, offset, self.weight, self.stride, self.padding,
+ self.dilation, self.groups, self.deform_groups,
+ False, self.im2col_step)
+ if input_pad:
+ out = out[:, :, :out.size(2) - pad_h, :out.size(3) -
+ pad_w].contiguous()
+ return out
+
+ def __repr__(self):
+ s = self.__class__.__name__
+ s += f'(in_channels={self.in_channels},\n'
+ s += f'out_channels={self.out_channels},\n'
+ s += f'kernel_size={self.kernel_size},\n'
+ s += f'stride={self.stride},\n'
+ s += f'padding={self.padding},\n'
+ s += f'dilation={self.dilation},\n'
+ s += f'groups={self.groups},\n'
+ s += f'deform_groups={self.deform_groups},\n'
+ # bias is not supported in DeformConv2d.
+ s += 'bias=False)'
+ return s
+
+
+@CONV_LAYERS.register_module('DCN')
+class DeformConv2dPack(DeformConv2d):
+ """A Deformable Conv Encapsulation that acts as normal Conv layers.
+
+ The offset tensor is like `[y0, x0, y1, x1, y2, x2, ..., y8, x8]`.
+ The spatial arrangement is like:
+
+ .. code:: text
+
+ (x0, y0) (x1, y1) (x2, y2)
+ (x3, y3) (x4, y4) (x5, y5)
+ (x6, y6) (x7, y7) (x8, y8)
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
+ stride (int or tuple[int]): Same as nn.Conv2d.
+ padding (int or tuple[int]): Same as nn.Conv2d.
+ dilation (int or tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ bias (bool or str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+ False.
+ """
+
+ _version = 2
+
+ def __init__(self, *args, **kwargs):
+ super(DeformConv2dPack, self).__init__(*args, **kwargs)
+ self.conv_offset = nn.Conv2d(
+ self.in_channels,
+ self.deform_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
+ kernel_size=self.kernel_size,
+ stride=_pair(self.stride),
+ padding=_pair(self.padding),
+ dilation=_pair(self.dilation),
+ bias=True)
+ self.init_offset()
+
+ def init_offset(self):
+ self.conv_offset.weight.data.zero_()
+ self.conv_offset.bias.data.zero_()
+
+ def forward(self, x):
+ offset = self.conv_offset(x)
+ return deform_conv2d(x, offset, self.weight, self.stride, self.padding,
+ self.dilation, self.groups, self.deform_groups,
+ False, self.im2col_step)
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ version = local_metadata.get('version', None)
+
+ if version is None or version < 2:
+ # the key is different in early versions
+ # In version < 2, DeformConvPack loads previous benchmark models.
+ if (prefix + 'conv_offset.weight' not in state_dict
+ and prefix[:-1] + '_offset.weight' in state_dict):
+ state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
+ prefix[:-1] + '_offset.weight')
+ if (prefix + 'conv_offset.bias' not in state_dict
+ and prefix[:-1] + '_offset.bias' in state_dict):
+ state_dict[prefix +
+ 'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
+ '_offset.bias')
+
+ if version is not None and version > 1:
+ print_log(
+ f'DeformConv2dPack {prefix.rstrip(".")} is upgraded to '
+ 'version 2.',
+ logger='root')
+
+ super()._load_from_state_dict(state_dict, prefix, local_metadata,
+ strict, missing_keys, unexpected_keys,
+ error_msgs)
diff --git a/src/custom_mmpkg/custom_mmcv/ops/deform_roi_pool.py b/src/custom_mmpkg/custom_mmcv/ops/deform_roi_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc245ba91fee252226ba22e76bb94a35db9a629b
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/deform_roi_pool.py
@@ -0,0 +1,204 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch import nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['deform_roi_pool_forward', 'deform_roi_pool_backward'])
+
+
+class DeformRoIPoolFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input, rois, offset, output_size, spatial_scale,
+ sampling_ratio, gamma):
+ return g.op(
+ 'mmcv::MMCVDeformRoIPool',
+ input,
+ rois,
+ offset,
+ pooled_height_i=output_size[0],
+ pooled_width_i=output_size[1],
+ spatial_scale_f=spatial_scale,
+ sampling_ratio_f=sampling_ratio,
+ gamma_f=gamma)
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ rois,
+ offset,
+ output_size,
+ spatial_scale=1.0,
+ sampling_ratio=0,
+ gamma=0.1):
+ if offset is None:
+ offset = input.new_zeros(0)
+ ctx.output_size = _pair(output_size)
+ ctx.spatial_scale = float(spatial_scale)
+ ctx.sampling_ratio = int(sampling_ratio)
+ ctx.gamma = float(gamma)
+
+ assert rois.size(1) == 5, 'RoI must be (idx, x1, y1, x2, y2)!'
+
+ output_shape = (rois.size(0), input.size(1), ctx.output_size[0],
+ ctx.output_size[1])
+ output = input.new_zeros(output_shape)
+
+ ext_module.deform_roi_pool_forward(
+ input,
+ rois,
+ offset,
+ output,
+ pooled_height=ctx.output_size[0],
+ pooled_width=ctx.output_size[1],
+ spatial_scale=ctx.spatial_scale,
+ sampling_ratio=ctx.sampling_ratio,
+ gamma=ctx.gamma)
+
+ ctx.save_for_backward(input, rois, offset)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ input, rois, offset = ctx.saved_tensors
+ grad_input = grad_output.new_zeros(input.shape)
+ grad_offset = grad_output.new_zeros(offset.shape)
+
+ ext_module.deform_roi_pool_backward(
+ grad_output,
+ input,
+ rois,
+ offset,
+ grad_input,
+ grad_offset,
+ pooled_height=ctx.output_size[0],
+ pooled_width=ctx.output_size[1],
+ spatial_scale=ctx.spatial_scale,
+ sampling_ratio=ctx.sampling_ratio,
+ gamma=ctx.gamma)
+ if grad_offset.numel() == 0:
+ grad_offset = None
+ return grad_input, None, grad_offset, None, None, None, None
+
+
+deform_roi_pool = DeformRoIPoolFunction.apply
+
+
+class DeformRoIPool(nn.Module):
+
+ def __init__(self,
+ output_size,
+ spatial_scale=1.0,
+ sampling_ratio=0,
+ gamma=0.1):
+ super(DeformRoIPool, self).__init__()
+ self.output_size = _pair(output_size)
+ self.spatial_scale = float(spatial_scale)
+ self.sampling_ratio = int(sampling_ratio)
+ self.gamma = float(gamma)
+
+ def forward(self, input, rois, offset=None):
+ return deform_roi_pool(input, rois, offset, self.output_size,
+ self.spatial_scale, self.sampling_ratio,
+ self.gamma)
+
+
+class DeformRoIPoolPack(DeformRoIPool):
+
+ def __init__(self,
+ output_size,
+ output_channels,
+ deform_fc_channels=1024,
+ spatial_scale=1.0,
+ sampling_ratio=0,
+ gamma=0.1):
+ super(DeformRoIPoolPack, self).__init__(output_size, spatial_scale,
+ sampling_ratio, gamma)
+
+ self.output_channels = output_channels
+ self.deform_fc_channels = deform_fc_channels
+
+ self.offset_fc = nn.Sequential(
+ nn.Linear(
+ self.output_size[0] * self.output_size[1] *
+ self.output_channels, self.deform_fc_channels),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.deform_fc_channels, self.deform_fc_channels),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.deform_fc_channels,
+ self.output_size[0] * self.output_size[1] * 2))
+ self.offset_fc[-1].weight.data.zero_()
+ self.offset_fc[-1].bias.data.zero_()
+
+ def forward(self, input, rois):
+ assert input.size(1) == self.output_channels
+ x = deform_roi_pool(input, rois, None, self.output_size,
+ self.spatial_scale, self.sampling_ratio,
+ self.gamma)
+ rois_num = rois.size(0)
+ offset = self.offset_fc(x.view(rois_num, -1))
+ offset = offset.view(rois_num, 2, self.output_size[0],
+ self.output_size[1])
+ return deform_roi_pool(input, rois, offset, self.output_size,
+ self.spatial_scale, self.sampling_ratio,
+ self.gamma)
+
+
+class ModulatedDeformRoIPoolPack(DeformRoIPool):
+
+ def __init__(self,
+ output_size,
+ output_channels,
+ deform_fc_channels=1024,
+ spatial_scale=1.0,
+ sampling_ratio=0,
+ gamma=0.1):
+ super(ModulatedDeformRoIPoolPack,
+ self).__init__(output_size, spatial_scale, sampling_ratio, gamma)
+
+ self.output_channels = output_channels
+ self.deform_fc_channels = deform_fc_channels
+
+ self.offset_fc = nn.Sequential(
+ nn.Linear(
+ self.output_size[0] * self.output_size[1] *
+ self.output_channels, self.deform_fc_channels),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.deform_fc_channels, self.deform_fc_channels),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.deform_fc_channels,
+ self.output_size[0] * self.output_size[1] * 2))
+ self.offset_fc[-1].weight.data.zero_()
+ self.offset_fc[-1].bias.data.zero_()
+
+ self.mask_fc = nn.Sequential(
+ nn.Linear(
+ self.output_size[0] * self.output_size[1] *
+ self.output_channels, self.deform_fc_channels),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.deform_fc_channels,
+ self.output_size[0] * self.output_size[1] * 1),
+ nn.Sigmoid())
+ self.mask_fc[2].weight.data.zero_()
+ self.mask_fc[2].bias.data.zero_()
+
+ def forward(self, input, rois):
+ assert input.size(1) == self.output_channels
+ x = deform_roi_pool(input, rois, None, self.output_size,
+ self.spatial_scale, self.sampling_ratio,
+ self.gamma)
+ rois_num = rois.size(0)
+ offset = self.offset_fc(x.view(rois_num, -1))
+ offset = offset.view(rois_num, 2, self.output_size[0],
+ self.output_size[1])
+ mask = self.mask_fc(x.view(rois_num, -1))
+ mask = mask.view(rois_num, 1, self.output_size[0], self.output_size[1])
+ d = deform_roi_pool(input, rois, offset, self.output_size,
+ self.spatial_scale, self.sampling_ratio,
+ self.gamma)
+ return d * mask
diff --git a/src/custom_mmpkg/custom_mmcv/ops/deprecated_wrappers.py b/src/custom_mmpkg/custom_mmcv/ops/deprecated_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2e593df9ee57637038683d7a1efaa347b2b69e7
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/deprecated_wrappers.py
@@ -0,0 +1,43 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This file is for backward compatibility.
+# Module wrappers for empty tensor have been moved to mmcv.cnn.bricks.
+import warnings
+
+from ..cnn.bricks.wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d
+
+
+class Conv2d_deprecated(Conv2d):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ warnings.warn(
+ 'Importing Conv2d wrapper from "mmcv.ops" will be deprecated in'
+ ' the future. Please import them from "mmcv.cnn" instead')
+
+
+class ConvTranspose2d_deprecated(ConvTranspose2d):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ warnings.warn(
+ 'Importing ConvTranspose2d wrapper from "mmcv.ops" will be '
+ 'deprecated in the future. Please import them from "mmcv.cnn" '
+ 'instead')
+
+
+class MaxPool2d_deprecated(MaxPool2d):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ warnings.warn(
+ 'Importing MaxPool2d wrapper from "mmcv.ops" will be deprecated in'
+ ' the future. Please import them from "mmcv.cnn" instead')
+
+
+class Linear_deprecated(Linear):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ warnings.warn(
+ 'Importing Linear wrapper from "mmcv.ops" will be deprecated in'
+ ' the future. Please import them from "mmcv.cnn" instead')
diff --git a/src/custom_mmpkg/custom_mmcv/ops/focal_loss.py b/src/custom_mmpkg/custom_mmcv/ops/focal_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..763bc93bd2575c49ca8ccf20996bbd92d1e0d1a4
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/focal_loss.py
@@ -0,0 +1,212 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+ 'sigmoid_focal_loss_forward', 'sigmoid_focal_loss_backward',
+ 'softmax_focal_loss_forward', 'softmax_focal_loss_backward'
+])
+
+
+class SigmoidFocalLossFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input, target, gamma, alpha, weight, reduction):
+ return g.op(
+ 'mmcv::MMCVSigmoidFocalLoss',
+ input,
+ target,
+ gamma_f=gamma,
+ alpha_f=alpha,
+ weight_f=weight,
+ reduction_s=reduction)
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ target,
+ gamma=2.0,
+ alpha=0.25,
+ weight=None,
+ reduction='mean'):
+
+ assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
+ assert input.dim() == 2
+ assert target.dim() == 1
+ assert input.size(0) == target.size(0)
+ if weight is None:
+ weight = input.new_empty(0)
+ else:
+ assert weight.dim() == 1
+ assert input.size(1) == weight.size(0)
+ ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2}
+ assert reduction in ctx.reduction_dict.keys()
+
+ ctx.gamma = float(gamma)
+ ctx.alpha = float(alpha)
+ ctx.reduction = ctx.reduction_dict[reduction]
+
+ output = input.new_zeros(input.size())
+
+ ext_module.sigmoid_focal_loss_forward(
+ input, target, weight, output, gamma=ctx.gamma, alpha=ctx.alpha)
+ if ctx.reduction == ctx.reduction_dict['mean']:
+ output = output.sum() / input.size(0)
+ elif ctx.reduction == ctx.reduction_dict['sum']:
+ output = output.sum()
+ ctx.save_for_backward(input, target, weight)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ input, target, weight = ctx.saved_tensors
+
+ grad_input = input.new_zeros(input.size())
+
+ ext_module.sigmoid_focal_loss_backward(
+ input,
+ target,
+ weight,
+ grad_input,
+ gamma=ctx.gamma,
+ alpha=ctx.alpha)
+
+ grad_input *= grad_output
+ if ctx.reduction == ctx.reduction_dict['mean']:
+ grad_input /= input.size(0)
+ return grad_input, None, None, None, None, None
+
+
+sigmoid_focal_loss = SigmoidFocalLossFunction.apply
+
+
+class SigmoidFocalLoss(nn.Module):
+
+ def __init__(self, gamma, alpha, weight=None, reduction='mean'):
+ super(SigmoidFocalLoss, self).__init__()
+ self.gamma = gamma
+ self.alpha = alpha
+ self.register_buffer('weight', weight)
+ self.reduction = reduction
+
+ def forward(self, input, target):
+ return sigmoid_focal_loss(input, target, self.gamma, self.alpha,
+ self.weight, self.reduction)
+
+ def __repr__(self):
+ s = self.__class__.__name__
+ s += f'(gamma={self.gamma}, '
+ s += f'alpha={self.alpha}, '
+ s += f'reduction={self.reduction})'
+ return s
+
+
+class SoftmaxFocalLossFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input, target, gamma, alpha, weight, reduction):
+ return g.op(
+ 'mmcv::MMCVSoftmaxFocalLoss',
+ input,
+ target,
+ gamma_f=gamma,
+ alpha_f=alpha,
+ weight_f=weight,
+ reduction_s=reduction)
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ target,
+ gamma=2.0,
+ alpha=0.25,
+ weight=None,
+ reduction='mean'):
+
+ assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
+ assert input.dim() == 2
+ assert target.dim() == 1
+ assert input.size(0) == target.size(0)
+ if weight is None:
+ weight = input.new_empty(0)
+ else:
+ assert weight.dim() == 1
+ assert input.size(1) == weight.size(0)
+ ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2}
+ assert reduction in ctx.reduction_dict.keys()
+
+ ctx.gamma = float(gamma)
+ ctx.alpha = float(alpha)
+ ctx.reduction = ctx.reduction_dict[reduction]
+
+ channel_stats, _ = torch.max(input, dim=1)
+ input_softmax = input - channel_stats.unsqueeze(1).expand_as(input)
+ input_softmax.exp_()
+
+ channel_stats = input_softmax.sum(dim=1)
+ input_softmax /= channel_stats.unsqueeze(1).expand_as(input)
+
+ output = input.new_zeros(input.size(0))
+ ext_module.softmax_focal_loss_forward(
+ input_softmax,
+ target,
+ weight,
+ output,
+ gamma=ctx.gamma,
+ alpha=ctx.alpha)
+
+ if ctx.reduction == ctx.reduction_dict['mean']:
+ output = output.sum() / input.size(0)
+ elif ctx.reduction == ctx.reduction_dict['sum']:
+ output = output.sum()
+ ctx.save_for_backward(input_softmax, target, weight)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input_softmax, target, weight = ctx.saved_tensors
+ buff = input_softmax.new_zeros(input_softmax.size(0))
+ grad_input = input_softmax.new_zeros(input_softmax.size())
+
+ ext_module.softmax_focal_loss_backward(
+ input_softmax,
+ target,
+ weight,
+ buff,
+ grad_input,
+ gamma=ctx.gamma,
+ alpha=ctx.alpha)
+
+ grad_input *= grad_output
+ if ctx.reduction == ctx.reduction_dict['mean']:
+ grad_input /= input_softmax.size(0)
+ return grad_input, None, None, None, None, None
+
+
+softmax_focal_loss = SoftmaxFocalLossFunction.apply
+
+
+class SoftmaxFocalLoss(nn.Module):
+
+ def __init__(self, gamma, alpha, weight=None, reduction='mean'):
+ super(SoftmaxFocalLoss, self).__init__()
+ self.gamma = gamma
+ self.alpha = alpha
+ self.register_buffer('weight', weight)
+ self.reduction = reduction
+
+ def forward(self, input, target):
+ return softmax_focal_loss(input, target, self.gamma, self.alpha,
+ self.weight, self.reduction)
+
+ def __repr__(self):
+ s = self.__class__.__name__
+ s += f'(gamma={self.gamma}, '
+ s += f'alpha={self.alpha}, '
+ s += f'reduction={self.reduction})'
+ return s
diff --git a/src/custom_mmpkg/custom_mmcv/ops/furthest_point_sample.py b/src/custom_mmpkg/custom_mmcv/ops/furthest_point_sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..374b7a878f1972c183941af28ba1df216ac1a60f
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/furthest_point_sample.py
@@ -0,0 +1,83 @@
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+ 'furthest_point_sampling_forward',
+ 'furthest_point_sampling_with_dist_forward'
+])
+
+
+class FurthestPointSampling(Function):
+ """Uses iterative furthest point sampling to select a set of features whose
+ corresponding points have the furthest distance."""
+
+ @staticmethod
+ def forward(ctx, points_xyz: torch.Tensor,
+ num_points: int) -> torch.Tensor:
+ """
+ Args:
+ points_xyz (Tensor): (B, N, 3) where N > num_points.
+ num_points (int): Number of points in the sampled set.
+
+ Returns:
+ Tensor: (B, num_points) indices of the sampled points.
+ """
+ assert points_xyz.is_contiguous()
+
+ B, N = points_xyz.size()[:2]
+ output = torch.cuda.IntTensor(B, num_points)
+ temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
+
+ ext_module.furthest_point_sampling_forward(
+ points_xyz,
+ temp,
+ output,
+ b=B,
+ n=N,
+ m=num_points,
+ )
+ if torch.__version__ != 'parrots':
+ ctx.mark_non_differentiable(output)
+ return output
+
+ @staticmethod
+ def backward(xyz, a=None):
+ return None, None
+
+
+class FurthestPointSamplingWithDist(Function):
+ """Uses iterative furthest point sampling to select a set of features whose
+ corresponding points have the furthest distance."""
+
+ @staticmethod
+ def forward(ctx, points_dist: torch.Tensor,
+ num_points: int) -> torch.Tensor:
+ """
+ Args:
+ points_dist (Tensor): (B, N, N) Distance between each point pair.
+ num_points (int): Number of points in the sampled set.
+
+ Returns:
+ Tensor: (B, num_points) indices of the sampled points.
+ """
+ assert points_dist.is_contiguous()
+
+ B, N, _ = points_dist.size()
+ output = points_dist.new_zeros([B, num_points], dtype=torch.int32)
+ temp = points_dist.new_zeros([B, N]).fill_(1e10)
+
+ ext_module.furthest_point_sampling_with_dist_forward(
+ points_dist, temp, output, b=B, n=N, m=num_points)
+ if torch.__version__ != 'parrots':
+ ctx.mark_non_differentiable(output)
+ return output
+
+ @staticmethod
+ def backward(xyz, a=None):
+ return None, None
+
+
+furthest_point_sample = FurthestPointSampling.apply
+furthest_point_sample_with_dist = FurthestPointSamplingWithDist.apply
diff --git a/src/custom_mmpkg/custom_mmcv/ops/fused_bias_leakyrelu.py b/src/custom_mmpkg/custom_mmcv/ops/fused_bias_leakyrelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d12508469c6c8fa1884debece44c58d158cb6fa
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/fused_bias_leakyrelu.py
@@ -0,0 +1,268 @@
+# modified from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
+
+# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
+# NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator
+# Augmentation (ADA)
+# =======================================================================
+
+# 1. Definitions
+
+# "Licensor" means any person or entity that distributes its Work.
+
+# "Software" means the original work of authorship made available under
+# this License.
+
+# "Work" means the Software and any additions to or derivative works of
+# the Software that are made available under this License.
+
+# The terms "reproduce," "reproduction," "derivative works," and
+# "distribution" have the meaning as provided under U.S. copyright law;
+# provided, however, that for the purposes of this License, derivative
+# works shall not include works that remain separable from, or merely
+# link (or bind by name) to the interfaces of, the Work.
+
+# Works, including the Software, are "made available" under this License
+# by including in or with the Work either (a) a copyright notice
+# referencing the applicability of this License to the Work, or (b) a
+# copy of this License.
+
+# 2. License Grants
+
+# 2.1 Copyright Grant. Subject to the terms and conditions of this
+# License, each Licensor grants to you a perpetual, worldwide,
+# non-exclusive, royalty-free, copyright license to reproduce,
+# prepare derivative works of, publicly display, publicly perform,
+# sublicense and distribute its Work and any resulting derivative
+# works in any form.
+
+# 3. Limitations
+
+# 3.1 Redistribution. You may reproduce or distribute the Work only
+# if (a) you do so under this License, (b) you include a complete
+# copy of this License with your distribution, and (c) you retain
+# without modification any copyright, patent, trademark, or
+# attribution notices that are present in the Work.
+
+# 3.2 Derivative Works. You may specify that additional or different
+# terms apply to the use, reproduction, and distribution of your
+# derivative works of the Work ("Your Terms") only if (a) Your Terms
+# provide that the use limitation in Section 3.3 applies to your
+# derivative works, and (b) you identify the specific derivative
+# works that are subject to Your Terms. Notwithstanding Your Terms,
+# this License (including the redistribution requirements in Section
+# 3.1) will continue to apply to the Work itself.
+
+# 3.3 Use Limitation. The Work and any derivative works thereof only
+# may be used or intended for use non-commercially. Notwithstanding
+# the foregoing, NVIDIA and its affiliates may use the Work and any
+# derivative works commercially. As used herein, "non-commercially"
+# means for research or evaluation purposes only.
+
+# 3.4 Patent Claims. If you bring or threaten to bring a patent claim
+# against any Licensor (including any claim, cross-claim or
+# counterclaim in a lawsuit) to enforce any patents that you allege
+# are infringed by any Work, then your rights under this License from
+# such Licensor (including the grant in Section 2.1) will terminate
+# immediately.
+
+# 3.5 Trademarks. This License does not grant any rights to use any
+# Licensor’s or its affiliates’ names, logos, or trademarks, except
+# as necessary to reproduce the notices described in this License.
+
+# 3.6 Termination. If you violate any term of this License, then your
+# rights under this License (including the grant in Section 2.1) will
+# terminate immediately.
+
+# 4. Disclaimer of Warranty.
+
+# THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
+# NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
+# THIS LICENSE.
+
+# 5. Limitation of Liability.
+
+# EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
+# THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
+# SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
+# INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
+# OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
+# (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
+# LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
+# COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGES.
+
+# =======================================================================
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['fused_bias_leakyrelu'])
+
+
+class FusedBiasLeakyReLUFunctionBackward(Function):
+ """Calculate second order deviation.
+
+ This function is to compute the second order deviation for the fused leaky
+ relu operation.
+ """
+
+ @staticmethod
+ def forward(ctx, grad_output, out, negative_slope, scale):
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+
+ empty = grad_output.new_empty(0)
+
+ grad_input = ext_module.fused_bias_leakyrelu(
+ grad_output,
+ empty,
+ out,
+ act=3,
+ grad=1,
+ alpha=negative_slope,
+ scale=scale)
+
+ dim = [0]
+
+ if grad_input.ndim > 2:
+ dim += list(range(2, grad_input.ndim))
+
+ grad_bias = grad_input.sum(dim).detach()
+
+ return grad_input, grad_bias
+
+ @staticmethod
+ def backward(ctx, gradgrad_input, gradgrad_bias):
+ out, = ctx.saved_tensors
+
+ # The second order deviation, in fact, contains two parts, while the
+ # the first part is zero. Thus, we direct consider the second part
+ # which is similar with the first order deviation in implementation.
+ gradgrad_out = ext_module.fused_bias_leakyrelu(
+ gradgrad_input,
+ gradgrad_bias.to(out.dtype),
+ out,
+ act=3,
+ grad=1,
+ alpha=ctx.negative_slope,
+ scale=ctx.scale)
+
+ return gradgrad_out, None, None, None
+
+
+class FusedBiasLeakyReLUFunction(Function):
+
+ @staticmethod
+ def forward(ctx, input, bias, negative_slope, scale):
+ empty = input.new_empty(0)
+
+ out = ext_module.fused_bias_leakyrelu(
+ input,
+ bias,
+ empty,
+ act=3,
+ grad=0,
+ alpha=negative_slope,
+ scale=scale)
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ out, = ctx.saved_tensors
+
+ grad_input, grad_bias = FusedBiasLeakyReLUFunctionBackward.apply(
+ grad_output, out, ctx.negative_slope, ctx.scale)
+
+ return grad_input, grad_bias, None, None
+
+
+class FusedBiasLeakyReLU(nn.Module):
+ """Fused bias leaky ReLU.
+
+ This function is introduced in the StyleGAN2:
+ http://arxiv.org/abs/1912.04958
+
+ The bias term comes from the convolution operation. In addition, to keep
+ the variance of the feature map or gradients unchanged, they also adopt a
+ scale similarly with Kaiming initialization. However, since the
+ :math:`1+{alpha}^2` : is too small, we can just ignore it. Therefore, the
+ final scale is just :math:`\sqrt{2}`:. Of course, you may change it with # noqa: W605, E501
+ your own scale.
+
+ TODO: Implement the CPU version.
+
+ Args:
+ channel (int): The channel number of the feature map.
+ negative_slope (float, optional): Same as nn.LeakyRelu.
+ Defaults to 0.2.
+ scale (float, optional): A scalar to adjust the variance of the feature
+ map. Defaults to 2**0.5.
+ """
+
+ def __init__(self, num_channels, negative_slope=0.2, scale=2**0.5):
+ super(FusedBiasLeakyReLU, self).__init__()
+
+ self.bias = nn.Parameter(torch.zeros(num_channels))
+ self.negative_slope = negative_slope
+ self.scale = scale
+
+ def forward(self, input):
+ return fused_bias_leakyrelu(input, self.bias, self.negative_slope,
+ self.scale)
+
+
+def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5):
+ """Fused bias leaky ReLU function.
+
+ This function is introduced in the StyleGAN2:
+ http://arxiv.org/abs/1912.04958
+
+ The bias term comes from the convolution operation. In addition, to keep
+ the variance of the feature map or gradients unchanged, they also adopt a
+ scale similarly with Kaiming initialization. However, since the
+ :math:`1+{alpha}^2` : is too small, we can just ignore it. Therefore, the
+ final scale is just :math:`\sqrt{2}`:. Of course, you may change it with # noqa: W605, E501
+ your own scale.
+
+ Args:
+ input (torch.Tensor): Input feature map.
+ bias (nn.Parameter): The bias from convolution operation.
+ negative_slope (float, optional): Same as nn.LeakyRelu.
+ Defaults to 0.2.
+ scale (float, optional): A scalar to adjust the variance of the feature
+ map. Defaults to 2**0.5.
+
+ Returns:
+ torch.Tensor: Feature map after non-linear activation.
+ """
+
+ if not input.is_cuda:
+ return bias_leakyrelu_ref(input, bias, negative_slope, scale)
+
+ return FusedBiasLeakyReLUFunction.apply(input, bias.to(input.dtype),
+ negative_slope, scale)
+
+
+def bias_leakyrelu_ref(x, bias, negative_slope=0.2, scale=2**0.5):
+
+ if bias is not None:
+ assert bias.ndim == 1
+ assert bias.shape[0] == x.shape[1]
+ x = x + bias.reshape([-1 if i == 1 else 1 for i in range(x.ndim)])
+
+ x = F.leaky_relu(x, negative_slope)
+ if scale != 1:
+ x = x * scale
+
+ return x
diff --git a/src/custom_mmpkg/custom_mmcv/ops/gather_points.py b/src/custom_mmpkg/custom_mmcv/ops/gather_points.py
new file mode 100644
index 0000000000000000000000000000000000000000..f52f1677d8ea0facafc56a3672d37adb44677ff3
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/gather_points.py
@@ -0,0 +1,57 @@
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['gather_points_forward', 'gather_points_backward'])
+
+
+class GatherPoints(Function):
+ """Gather points with given index."""
+
+ @staticmethod
+ def forward(ctx, features: torch.Tensor,
+ indices: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ features (Tensor): (B, C, N) features to gather.
+ indices (Tensor): (B, M) where M is the number of points.
+
+ Returns:
+ Tensor: (B, C, M) where M is the number of points.
+ """
+ assert features.is_contiguous()
+ assert indices.is_contiguous()
+
+ B, npoint = indices.size()
+ _, C, N = features.size()
+ output = torch.cuda.FloatTensor(B, C, npoint)
+
+ ext_module.gather_points_forward(
+ features, indices, output, b=B, c=C, n=N, npoints=npoint)
+
+ ctx.for_backwards = (indices, C, N)
+ if torch.__version__ != 'parrots':
+ ctx.mark_non_differentiable(indices)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ idx, C, N = ctx.for_backwards
+ B, npoint = idx.size()
+
+ grad_features = torch.cuda.FloatTensor(B, C, N).zero_()
+ grad_out_data = grad_out.data.contiguous()
+ ext_module.gather_points_backward(
+ grad_out_data,
+ idx,
+ grad_features.data,
+ b=B,
+ c=C,
+ n=N,
+ npoints=npoint)
+ return grad_features, None
+
+
+gather_points = GatherPoints.apply
diff --git a/src/custom_mmpkg/custom_mmcv/ops/group_points.py b/src/custom_mmpkg/custom_mmcv/ops/group_points.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c3ec9d758ebe4e1c2205882af4be154008253a5
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/group_points.py
@@ -0,0 +1,224 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
+import torch
+from torch import nn as nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+from .ball_query import ball_query
+from .knn import knn
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['group_points_forward', 'group_points_backward'])
+
+
+class QueryAndGroup(nn.Module):
+ """Groups points with a ball query of radius.
+
+ Args:
+ max_radius (float): The maximum radius of the balls.
+ If None is given, we will use kNN sampling instead of ball query.
+ sample_num (int): Maximum number of features to gather in the ball.
+ min_radius (float, optional): The minimum radius of the balls.
+ Default: 0.
+ use_xyz (bool, optional): Whether to use xyz.
+ Default: True.
+ return_grouped_xyz (bool, optional): Whether to return grouped xyz.
+ Default: False.
+ normalize_xyz (bool, optional): Whether to normalize xyz.
+ Default: False.
+ uniform_sample (bool, optional): Whether to sample uniformly.
+ Default: False
+ return_unique_cnt (bool, optional): Whether to return the count of
+ unique samples. Default: False.
+ return_grouped_idx (bool, optional): Whether to return grouped idx.
+ Default: False.
+ """
+
+ def __init__(self,
+ max_radius,
+ sample_num,
+ min_radius=0,
+ use_xyz=True,
+ return_grouped_xyz=False,
+ normalize_xyz=False,
+ uniform_sample=False,
+ return_unique_cnt=False,
+ return_grouped_idx=False):
+ super().__init__()
+ self.max_radius = max_radius
+ self.min_radius = min_radius
+ self.sample_num = sample_num
+ self.use_xyz = use_xyz
+ self.return_grouped_xyz = return_grouped_xyz
+ self.normalize_xyz = normalize_xyz
+ self.uniform_sample = uniform_sample
+ self.return_unique_cnt = return_unique_cnt
+ self.return_grouped_idx = return_grouped_idx
+ if self.return_unique_cnt:
+ assert self.uniform_sample, \
+ 'uniform_sample should be True when ' \
+ 'returning the count of unique samples'
+ if self.max_radius is None:
+ assert not self.normalize_xyz, \
+ 'can not normalize grouped xyz when max_radius is None'
+
+ def forward(self, points_xyz, center_xyz, features=None):
+ """
+ Args:
+ points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
+ center_xyz (Tensor): (B, npoint, 3) coordinates of the centriods.
+ features (Tensor): (B, C, N) Descriptors of the features.
+
+ Returns:
+ Tensor: (B, 3 + C, npoint, sample_num) Grouped feature.
+ """
+ # if self.max_radius is None, we will perform kNN instead of ball query
+ # idx is of shape [B, npoint, sample_num]
+ if self.max_radius is None:
+ idx = knn(self.sample_num, points_xyz, center_xyz, False)
+ idx = idx.transpose(1, 2).contiguous()
+ else:
+ idx = ball_query(self.min_radius, self.max_radius, self.sample_num,
+ points_xyz, center_xyz)
+
+ if self.uniform_sample:
+ unique_cnt = torch.zeros((idx.shape[0], idx.shape[1]))
+ for i_batch in range(idx.shape[0]):
+ for i_region in range(idx.shape[1]):
+ unique_ind = torch.unique(idx[i_batch, i_region, :])
+ num_unique = unique_ind.shape[0]
+ unique_cnt[i_batch, i_region] = num_unique
+ sample_ind = torch.randint(
+ 0,
+ num_unique, (self.sample_num - num_unique, ),
+ dtype=torch.long)
+ all_ind = torch.cat((unique_ind, unique_ind[sample_ind]))
+ idx[i_batch, i_region, :] = all_ind
+
+ xyz_trans = points_xyz.transpose(1, 2).contiguous()
+ # (B, 3, npoint, sample_num)
+ grouped_xyz = grouping_operation(xyz_trans, idx)
+ grouped_xyz_diff = grouped_xyz - \
+ center_xyz.transpose(1, 2).unsqueeze(-1) # relative offsets
+ if self.normalize_xyz:
+ grouped_xyz_diff /= self.max_radius
+
+ if features is not None:
+ grouped_features = grouping_operation(features, idx)
+ if self.use_xyz:
+ # (B, C + 3, npoint, sample_num)
+ new_features = torch.cat([grouped_xyz_diff, grouped_features],
+ dim=1)
+ else:
+ new_features = grouped_features
+ else:
+ assert (self.use_xyz
+ ), 'Cannot have not features and not use xyz as a feature!'
+ new_features = grouped_xyz_diff
+
+ ret = [new_features]
+ if self.return_grouped_xyz:
+ ret.append(grouped_xyz)
+ if self.return_unique_cnt:
+ ret.append(unique_cnt)
+ if self.return_grouped_idx:
+ ret.append(idx)
+ if len(ret) == 1:
+ return ret[0]
+ else:
+ return tuple(ret)
+
+
+class GroupAll(nn.Module):
+ """Group xyz with feature.
+
+ Args:
+ use_xyz (bool): Whether to use xyz.
+ """
+
+ def __init__(self, use_xyz: bool = True):
+ super().__init__()
+ self.use_xyz = use_xyz
+
+ def forward(self,
+ xyz: torch.Tensor,
+ new_xyz: torch.Tensor,
+ features: torch.Tensor = None):
+ """
+ Args:
+ xyz (Tensor): (B, N, 3) xyz coordinates of the features.
+ new_xyz (Tensor): new xyz coordinates of the features.
+ features (Tensor): (B, C, N) features to group.
+
+ Returns:
+ Tensor: (B, C + 3, 1, N) Grouped feature.
+ """
+ grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
+ if features is not None:
+ grouped_features = features.unsqueeze(2)
+ if self.use_xyz:
+ # (B, 3 + C, 1, N)
+ new_features = torch.cat([grouped_xyz, grouped_features],
+ dim=1)
+ else:
+ new_features = grouped_features
+ else:
+ new_features = grouped_xyz
+
+ return new_features
+
+
+class GroupingOperation(Function):
+ """Group feature with given index."""
+
+ @staticmethod
+ def forward(ctx, features: torch.Tensor,
+ indices: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ features (Tensor): (B, C, N) tensor of features to group.
+ indices (Tensor): (B, npoint, nsample) the indices of
+ features to group with.
+
+ Returns:
+ Tensor: (B, C, npoint, nsample) Grouped features.
+ """
+ features = features.contiguous()
+ indices = indices.contiguous()
+
+ B, nfeatures, nsample = indices.size()
+ _, C, N = features.size()
+ output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
+
+ ext_module.group_points_forward(B, C, N, nfeatures, nsample, features,
+ indices, output)
+
+ ctx.for_backwards = (indices, N)
+ return output
+
+ @staticmethod
+ def backward(ctx,
+ grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ grad_out (Tensor): (B, C, npoint, nsample) tensor of the gradients
+ of the output from forward.
+
+ Returns:
+ Tensor: (B, C, N) gradient of the features.
+ """
+ idx, N = ctx.for_backwards
+
+ B, C, npoint, nsample = grad_out.size()
+ grad_features = torch.cuda.FloatTensor(B, C, N).zero_()
+
+ grad_out_data = grad_out.data.contiguous()
+ ext_module.group_points_backward(B, C, N, npoint, nsample,
+ grad_out_data, idx,
+ grad_features.data)
+ return grad_features, None
+
+
+grouping_operation = GroupingOperation.apply
diff --git a/src/custom_mmpkg/custom_mmcv/ops/info.py b/src/custom_mmpkg/custom_mmcv/ops/info.py
new file mode 100644
index 0000000000000000000000000000000000000000..29f2e5598ae2bb5866ccd15a7d3b4de33c0cd14d
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/info.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import glob
+import os
+
+import torch
+
+if torch.__version__ == 'parrots':
+ import parrots
+
+ def get_compiler_version():
+ return 'GCC ' + parrots.version.compiler
+
+ def get_compiling_cuda_version():
+ return parrots.version.cuda
+else:
+ from ..utils import ext_loader
+ ext_module = ext_loader.load_ext(
+ '_ext', ['get_compiler_version', 'get_compiling_cuda_version'])
+
+ def get_compiler_version():
+ return ext_module.get_compiler_version()
+
+ def get_compiling_cuda_version():
+ return ext_module.get_compiling_cuda_version()
+
+
+def get_onnxruntime_op_path():
+ wildcard = os.path.join(
+ os.path.abspath(os.path.dirname(os.path.dirname(__file__))),
+ '_ext_ort.*.so')
+
+ paths = glob.glob(wildcard)
+ if len(paths) > 0:
+ return paths[0]
+ else:
+ return ''
diff --git a/src/custom_mmpkg/custom_mmcv/ops/iou3d.py b/src/custom_mmpkg/custom_mmcv/ops/iou3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fc71979190323f44c09f8b7e1761cf49cd2d76b
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/iou3d.py
@@ -0,0 +1,85 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+ 'iou3d_boxes_iou_bev_forward', 'iou3d_nms_forward',
+ 'iou3d_nms_normal_forward'
+])
+
+
+def boxes_iou_bev(boxes_a, boxes_b):
+ """Calculate boxes IoU in the Bird's Eye View.
+
+ Args:
+ boxes_a (torch.Tensor): Input boxes a with shape (M, 5).
+ boxes_b (torch.Tensor): Input boxes b with shape (N, 5).
+
+ Returns:
+ ans_iou (torch.Tensor): IoU result with shape (M, N).
+ """
+ ans_iou = boxes_a.new_zeros(
+ torch.Size((boxes_a.shape[0], boxes_b.shape[0])))
+
+ ext_module.iou3d_boxes_iou_bev_forward(boxes_a.contiguous(),
+ boxes_b.contiguous(), ans_iou)
+
+ return ans_iou
+
+
+def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None):
+ """NMS function GPU implementation (for BEV boxes). The overlap of two
+ boxes for IoU calculation is defined as the exact overlapping area of the
+ two boxes. In this function, one can also set ``pre_max_size`` and
+ ``post_max_size``.
+
+ Args:
+ boxes (torch.Tensor): Input boxes with the shape of [N, 5]
+ ([x1, y1, x2, y2, ry]).
+ scores (torch.Tensor): Scores of boxes with the shape of [N].
+ thresh (float): Overlap threshold of NMS.
+ pre_max_size (int, optional): Max size of boxes before NMS.
+ Default: None.
+ post_max_size (int, optional): Max size of boxes after NMS.
+ Default: None.
+
+ Returns:
+ torch.Tensor: Indexes after NMS.
+ """
+ assert boxes.size(1) == 5, 'Input boxes shape should be [N, 5]'
+ order = scores.sort(0, descending=True)[1]
+
+ if pre_max_size is not None:
+ order = order[:pre_max_size]
+ boxes = boxes[order].contiguous()
+
+ keep = torch.zeros(boxes.size(0), dtype=torch.long)
+ num_out = ext_module.iou3d_nms_forward(boxes, keep, thresh)
+ keep = order[keep[:num_out].cuda(boxes.device)].contiguous()
+ if post_max_size is not None:
+ keep = keep[:post_max_size]
+ return keep
+
+
+def nms_normal_bev(boxes, scores, thresh):
+ """Normal NMS function GPU implementation (for BEV boxes). The overlap of
+ two boxes for IoU calculation is defined as the exact overlapping area of
+ the two boxes WITH their yaw angle set to 0.
+
+ Args:
+ boxes (torch.Tensor): Input boxes with shape (N, 5).
+ scores (torch.Tensor): Scores of predicted boxes with shape (N).
+ thresh (float): Overlap threshold of NMS.
+
+ Returns:
+ torch.Tensor: Remaining indices with scores in descending order.
+ """
+ assert boxes.shape[1] == 5, 'Input boxes shape should be [N, 5]'
+ order = scores.sort(0, descending=True)[1]
+
+ boxes = boxes[order].contiguous()
+
+ keep = torch.zeros(boxes.size(0), dtype=torch.long)
+ num_out = ext_module.iou3d_nms_normal_forward(boxes, keep, thresh)
+ return order[keep[:num_out].cuda(boxes.device)].contiguous()
diff --git a/src/custom_mmpkg/custom_mmcv/ops/knn.py b/src/custom_mmpkg/custom_mmcv/ops/knn.py
new file mode 100644
index 0000000000000000000000000000000000000000..f335785036669fc19239825b0aae6dde3f73bf92
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/knn.py
@@ -0,0 +1,77 @@
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['knn_forward'])
+
+
+class KNN(Function):
+ r"""KNN (CUDA) based on heap data structure.
+ Modified from `PAConv `_.
+
+ Find k-nearest points.
+ """
+
+ @staticmethod
+ def forward(ctx,
+ k: int,
+ xyz: torch.Tensor,
+ center_xyz: torch.Tensor = None,
+ transposed: bool = False) -> torch.Tensor:
+ """
+ Args:
+ k (int): number of nearest neighbors.
+ xyz (Tensor): (B, N, 3) if transposed == False, else (B, 3, N).
+ xyz coordinates of the features.
+ center_xyz (Tensor, optional): (B, npoint, 3) if transposed ==
+ False, else (B, 3, npoint). centers of the knn query.
+ Default: None.
+ transposed (bool, optional): whether the input tensors are
+ transposed. Should not explicitly use this keyword when
+ calling knn (=KNN.apply), just add the fourth param.
+ Default: False.
+
+ Returns:
+ Tensor: (B, k, npoint) tensor with the indices of
+ the features that form k-nearest neighbours.
+ """
+ assert (k > 0) & (k < 100), 'k should be in range(0, 100)'
+
+ if center_xyz is None:
+ center_xyz = xyz
+
+ if transposed:
+ xyz = xyz.transpose(2, 1).contiguous()
+ center_xyz = center_xyz.transpose(2, 1).contiguous()
+
+ assert xyz.is_contiguous() # [B, N, 3]
+ assert center_xyz.is_contiguous() # [B, npoint, 3]
+
+ center_xyz_device = center_xyz.get_device()
+ assert center_xyz_device == xyz.get_device(), \
+ 'center_xyz and xyz should be put on the same device'
+ if torch.cuda.current_device() != center_xyz_device:
+ torch.cuda.set_device(center_xyz_device)
+
+ B, npoint, _ = center_xyz.shape
+ N = xyz.shape[1]
+
+ idx = center_xyz.new_zeros((B, npoint, k)).int()
+ dist2 = center_xyz.new_zeros((B, npoint, k)).float()
+
+ ext_module.knn_forward(
+ xyz, center_xyz, idx, dist2, b=B, n=N, m=npoint, nsample=k)
+ # idx shape to [B, k, npoint]
+ idx = idx.transpose(2, 1).contiguous()
+ if torch.__version__ != 'parrots':
+ ctx.mark_non_differentiable(idx)
+ return idx
+
+ @staticmethod
+ def backward(ctx, a=None):
+ return None, None, None
+
+
+knn = KNN.apply
diff --git a/src/custom_mmpkg/custom_mmcv/ops/masked_conv.py b/src/custom_mmpkg/custom_mmcv/ops/masked_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd514cc204c1d571ea5dc7e74b038c0f477a008b
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/masked_conv.py
@@ -0,0 +1,111 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['masked_im2col_forward', 'masked_col2im_forward'])
+
+
+class MaskedConv2dFunction(Function):
+
+ @staticmethod
+ def symbolic(g, features, mask, weight, bias, padding, stride):
+ return g.op(
+ 'mmcv::MMCVMaskedConv2d',
+ features,
+ mask,
+ weight,
+ bias,
+ padding_i=padding,
+ stride_i=stride)
+
+ @staticmethod
+ def forward(ctx, features, mask, weight, bias, padding=0, stride=1):
+ assert mask.dim() == 3 and mask.size(0) == 1
+ assert features.dim() == 4 and features.size(0) == 1
+ assert features.size()[2:] == mask.size()[1:]
+ pad_h, pad_w = _pair(padding)
+ stride_h, stride_w = _pair(stride)
+ if stride_h != 1 or stride_w != 1:
+ raise ValueError(
+ 'Stride could not only be 1 in masked_conv2d currently.')
+ out_channel, in_channel, kernel_h, kernel_w = weight.size()
+
+ batch_size = features.size(0)
+ out_h = int(
+ math.floor((features.size(2) + 2 * pad_h -
+ (kernel_h - 1) - 1) / stride_h + 1))
+ out_w = int(
+ math.floor((features.size(3) + 2 * pad_w -
+ (kernel_h - 1) - 1) / stride_w + 1))
+ mask_inds = torch.nonzero(mask[0] > 0, as_tuple=False)
+ output = features.new_zeros(batch_size, out_channel, out_h, out_w)
+ if mask_inds.numel() > 0:
+ mask_h_idx = mask_inds[:, 0].contiguous()
+ mask_w_idx = mask_inds[:, 1].contiguous()
+ data_col = features.new_zeros(in_channel * kernel_h * kernel_w,
+ mask_inds.size(0))
+ ext_module.masked_im2col_forward(
+ features,
+ mask_h_idx,
+ mask_w_idx,
+ data_col,
+ kernel_h=kernel_h,
+ kernel_w=kernel_w,
+ pad_h=pad_h,
+ pad_w=pad_w)
+
+ masked_output = torch.addmm(1, bias[:, None], 1,
+ weight.view(out_channel, -1), data_col)
+ ext_module.masked_col2im_forward(
+ masked_output,
+ mask_h_idx,
+ mask_w_idx,
+ output,
+ height=out_h,
+ width=out_w,
+ channels=out_channel)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ return (None, ) * 5
+
+
+masked_conv2d = MaskedConv2dFunction.apply
+
+
+class MaskedConv2d(nn.Conv2d):
+ """A MaskedConv2d which inherits the official Conv2d.
+
+ The masked forward doesn't implement the backward function and only
+ supports the stride parameter to be 1 currently.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=True):
+ super(MaskedConv2d,
+ self).__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias)
+
+ def forward(self, input, mask=None):
+ if mask is None: # fallback to the normal Conv2d
+ return super(MaskedConv2d, self).forward(input)
+ else:
+ return masked_conv2d(input, mask, self.weight, self.bias,
+ self.padding)
diff --git a/src/custom_mmpkg/custom_mmcv/ops/merge_cells.py b/src/custom_mmpkg/custom_mmcv/ops/merge_cells.py
new file mode 100644
index 0000000000000000000000000000000000000000..48ca8cc0a8aca8432835bd760c0403a3c35b34cf
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/merge_cells.py
@@ -0,0 +1,149 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import abstractmethod
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..cnn import ConvModule
+
+
+class BaseMergeCell(nn.Module):
+ """The basic class for cells used in NAS-FPN and NAS-FCOS.
+
+ BaseMergeCell takes 2 inputs. After applying convolution
+ on them, they are resized to the target size. Then,
+ they go through binary_op, which depends on the type of cell.
+ If with_out_conv is True, the result of output will go through
+ another convolution layer.
+
+ Args:
+ in_channels (int): number of input channels in out_conv layer.
+ out_channels (int): number of output channels in out_conv layer.
+ with_out_conv (bool): Whether to use out_conv layer
+ out_conv_cfg (dict): Config dict for convolution layer, which should
+ contain "groups", "kernel_size", "padding", "bias" to build
+ out_conv layer.
+ out_norm_cfg (dict): Config dict for normalization layer in out_conv.
+ out_conv_order (tuple): The order of conv/norm/activation layers in
+ out_conv.
+ with_input1_conv (bool): Whether to use convolution on input1.
+ with_input2_conv (bool): Whether to use convolution on input2.
+ input_conv_cfg (dict): Config dict for building input1_conv layer and
+ input2_conv layer, which is expected to contain the type of
+ convolution.
+ Default: None, which means using conv2d.
+ input_norm_cfg (dict): Config dict for normalization layer in
+ input1_conv and input2_conv layer. Default: None.
+ upsample_mode (str): Interpolation method used to resize the output
+ of input1_conv and input2_conv to target size. Currently, we
+ support ['nearest', 'bilinear']. Default: 'nearest'.
+ """
+
+ def __init__(self,
+ fused_channels=256,
+ out_channels=256,
+ with_out_conv=True,
+ out_conv_cfg=dict(
+ groups=1, kernel_size=3, padding=1, bias=True),
+ out_norm_cfg=None,
+ out_conv_order=('act', 'conv', 'norm'),
+ with_input1_conv=False,
+ with_input2_conv=False,
+ input_conv_cfg=None,
+ input_norm_cfg=None,
+ upsample_mode='nearest'):
+ super(BaseMergeCell, self).__init__()
+ assert upsample_mode in ['nearest', 'bilinear']
+ self.with_out_conv = with_out_conv
+ self.with_input1_conv = with_input1_conv
+ self.with_input2_conv = with_input2_conv
+ self.upsample_mode = upsample_mode
+
+ if self.with_out_conv:
+ self.out_conv = ConvModule(
+ fused_channels,
+ out_channels,
+ **out_conv_cfg,
+ norm_cfg=out_norm_cfg,
+ order=out_conv_order)
+
+ self.input1_conv = self._build_input_conv(
+ out_channels, input_conv_cfg,
+ input_norm_cfg) if with_input1_conv else nn.Sequential()
+ self.input2_conv = self._build_input_conv(
+ out_channels, input_conv_cfg,
+ input_norm_cfg) if with_input2_conv else nn.Sequential()
+
+ def _build_input_conv(self, channel, conv_cfg, norm_cfg):
+ return ConvModule(
+ channel,
+ channel,
+ 3,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ bias=True)
+
+ @abstractmethod
+ def _binary_op(self, x1, x2):
+ pass
+
+ def _resize(self, x, size):
+ if x.shape[-2:] == size:
+ return x
+ elif x.shape[-2:] < size:
+ return F.interpolate(x, size=size, mode=self.upsample_mode)
+ else:
+ assert x.shape[-2] % size[-2] == 0 and x.shape[-1] % size[-1] == 0
+ kernel_size = x.shape[-1] // size[-1]
+ x = F.max_pool2d(x, kernel_size=kernel_size, stride=kernel_size)
+ return x
+
+ def forward(self, x1, x2, out_size=None):
+ assert x1.shape[:2] == x2.shape[:2]
+ assert out_size is None or len(out_size) == 2
+ if out_size is None: # resize to larger one
+ out_size = max(x1.size()[2:], x2.size()[2:])
+
+ x1 = self.input1_conv(x1)
+ x2 = self.input2_conv(x2)
+
+ x1 = self._resize(x1, out_size)
+ x2 = self._resize(x2, out_size)
+
+ x = self._binary_op(x1, x2)
+ if self.with_out_conv:
+ x = self.out_conv(x)
+ return x
+
+
+class SumCell(BaseMergeCell):
+
+ def __init__(self, in_channels, out_channels, **kwargs):
+ super(SumCell, self).__init__(in_channels, out_channels, **kwargs)
+
+ def _binary_op(self, x1, x2):
+ return x1 + x2
+
+
+class ConcatCell(BaseMergeCell):
+
+ def __init__(self, in_channels, out_channels, **kwargs):
+ super(ConcatCell, self).__init__(in_channels * 2, out_channels,
+ **kwargs)
+
+ def _binary_op(self, x1, x2):
+ ret = torch.cat([x1, x2], dim=1)
+ return ret
+
+
+class GlobalPoolingCell(BaseMergeCell):
+
+ def __init__(self, in_channels=None, out_channels=None, **kwargs):
+ super().__init__(in_channels, out_channels, **kwargs)
+ self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
+
+ def _binary_op(self, x1, x2):
+ x2_att = self.global_pool(x2).sigmoid()
+ return x2 + x2_att * x1
diff --git a/src/custom_mmpkg/custom_mmcv/ops/modulated_deform_conv.py b/src/custom_mmpkg/custom_mmcv/ops/modulated_deform_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..95b4828ef5ba35445856f6e19c0d565d8855c2ed
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/modulated_deform_conv.py
@@ -0,0 +1,282 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair, _single
+
+from custom_mmpkg.custom_mmcv.utils import deprecated_api_warning
+from ..cnn import CONV_LAYERS
+from ..utils import ext_loader, print_log
+
+ext_module = ext_loader.load_ext(
+ '_ext',
+ ['modulated_deform_conv_forward', 'modulated_deform_conv_backward'])
+
+
+class ModulatedDeformConv2dFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input, offset, mask, weight, bias, stride, padding,
+ dilation, groups, deform_groups):
+ input_tensors = [input, offset, mask, weight]
+ if bias is not None:
+ input_tensors.append(bias)
+ return g.op(
+ 'mmcv::MMCVModulatedDeformConv2d',
+ *input_tensors,
+ stride_i=stride,
+ padding_i=padding,
+ dilation_i=dilation,
+ groups_i=groups,
+ deform_groups_i=deform_groups)
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ offset,
+ mask,
+ weight,
+ bias=None,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deform_groups=1):
+ if input is not None and input.dim() != 4:
+ raise ValueError(
+ f'Expected 4D tensor as input, got {input.dim()}D tensor \
+ instead.')
+ ctx.stride = _pair(stride)
+ ctx.padding = _pair(padding)
+ ctx.dilation = _pair(dilation)
+ ctx.groups = groups
+ ctx.deform_groups = deform_groups
+ ctx.with_bias = bias is not None
+ if not ctx.with_bias:
+ bias = input.new_empty(0) # fake tensor
+ # When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
+ # amp won't cast the type of model (float32), but "offset" is cast
+ # to float16 by nn.Conv2d automatically, leading to the type
+ # mismatch with input (when it is float32) or weight.
+ # The flag for whether to use fp16 or amp is the type of "offset",
+ # we cast weight and input to temporarily support fp16 and amp
+ # whatever the pytorch version is.
+ input = input.type_as(offset)
+ weight = weight.type_as(input)
+ ctx.save_for_backward(input, offset, mask, weight, bias)
+ output = input.new_empty(
+ ModulatedDeformConv2dFunction._output_size(ctx, input, weight))
+ ctx._bufs = [input.new_empty(0), input.new_empty(0)]
+ ext_module.modulated_deform_conv_forward(
+ input,
+ weight,
+ bias,
+ ctx._bufs[0],
+ offset,
+ mask,
+ output,
+ ctx._bufs[1],
+ kernel_h=weight.size(2),
+ kernel_w=weight.size(3),
+ stride_h=ctx.stride[0],
+ stride_w=ctx.stride[1],
+ pad_h=ctx.padding[0],
+ pad_w=ctx.padding[1],
+ dilation_h=ctx.dilation[0],
+ dilation_w=ctx.dilation[1],
+ group=ctx.groups,
+ deformable_group=ctx.deform_groups,
+ with_bias=ctx.with_bias)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ input, offset, mask, weight, bias = ctx.saved_tensors
+ grad_input = torch.zeros_like(input)
+ grad_offset = torch.zeros_like(offset)
+ grad_mask = torch.zeros_like(mask)
+ grad_weight = torch.zeros_like(weight)
+ grad_bias = torch.zeros_like(bias)
+ grad_output = grad_output.contiguous()
+ ext_module.modulated_deform_conv_backward(
+ input,
+ weight,
+ bias,
+ ctx._bufs[0],
+ offset,
+ mask,
+ ctx._bufs[1],
+ grad_input,
+ grad_weight,
+ grad_bias,
+ grad_offset,
+ grad_mask,
+ grad_output,
+ kernel_h=weight.size(2),
+ kernel_w=weight.size(3),
+ stride_h=ctx.stride[0],
+ stride_w=ctx.stride[1],
+ pad_h=ctx.padding[0],
+ pad_w=ctx.padding[1],
+ dilation_h=ctx.dilation[0],
+ dilation_w=ctx.dilation[1],
+ group=ctx.groups,
+ deformable_group=ctx.deform_groups,
+ with_bias=ctx.with_bias)
+ if not ctx.with_bias:
+ grad_bias = None
+
+ return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
+ None, None, None, None, None)
+
+ @staticmethod
+ def _output_size(ctx, input, weight):
+ channels = weight.size(0)
+ output_size = (input.size(0), channels)
+ for d in range(input.dim() - 2):
+ in_size = input.size(d + 2)
+ pad = ctx.padding[d]
+ kernel = ctx.dilation[d] * (weight.size(d + 2) - 1) + 1
+ stride_ = ctx.stride[d]
+ output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
+ if not all(map(lambda s: s > 0, output_size)):
+ raise ValueError(
+ 'convolution input is too small (output would be ' +
+ 'x'.join(map(str, output_size)) + ')')
+ return output_size
+
+
+modulated_deform_conv2d = ModulatedDeformConv2dFunction.apply
+
+
+class ModulatedDeformConv2d(nn.Module):
+
+ @deprecated_api_warning({'deformable_groups': 'deform_groups'},
+ cls_name='ModulatedDeformConv2d')
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deform_groups=1,
+ bias=True):
+ super(ModulatedDeformConv2d, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = _pair(stride)
+ self.padding = _pair(padding)
+ self.dilation = _pair(dilation)
+ self.groups = groups
+ self.deform_groups = deform_groups
+ # enable compatibility with nn.Conv2d
+ self.transposed = False
+ self.output_padding = _single(0)
+
+ self.weight = nn.Parameter(
+ torch.Tensor(out_channels, in_channels // groups,
+ *self.kernel_size))
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
+ else:
+ self.register_parameter('bias', None)
+ self.init_weights()
+
+ def init_weights(self):
+ n = self.in_channels
+ for k in self.kernel_size:
+ n *= k
+ stdv = 1. / math.sqrt(n)
+ self.weight.data.uniform_(-stdv, stdv)
+ if self.bias is not None:
+ self.bias.data.zero_()
+
+ def forward(self, x, offset, mask):
+ return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
+ self.stride, self.padding,
+ self.dilation, self.groups,
+ self.deform_groups)
+
+
+@CONV_LAYERS.register_module('DCNv2')
+class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
+ """A ModulatedDeformable Conv Encapsulation that acts as normal Conv
+ layers.
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
+ stride (int): Same as nn.Conv2d, while tuple is not supported.
+ padding (int): Same as nn.Conv2d, while tuple is not supported.
+ dilation (int): Same as nn.Conv2d, while tuple is not supported.
+ groups (int): Same as nn.Conv2d.
+ bias (bool or str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+ False.
+ """
+
+ _version = 2
+
+ def __init__(self, *args, **kwargs):
+ super(ModulatedDeformConv2dPack, self).__init__(*args, **kwargs)
+ self.conv_offset = nn.Conv2d(
+ self.in_channels,
+ self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
+ kernel_size=self.kernel_size,
+ stride=self.stride,
+ padding=self.padding,
+ dilation=self.dilation,
+ bias=True)
+ self.init_weights()
+
+ def init_weights(self):
+ super(ModulatedDeformConv2dPack, self).init_weights()
+ if hasattr(self, 'conv_offset'):
+ self.conv_offset.weight.data.zero_()
+ self.conv_offset.bias.data.zero_()
+
+ def forward(self, x):
+ out = self.conv_offset(x)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+ offset = torch.cat((o1, o2), dim=1)
+ mask = torch.sigmoid(mask)
+ return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
+ self.stride, self.padding,
+ self.dilation, self.groups,
+ self.deform_groups)
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ version = local_metadata.get('version', None)
+
+ if version is None or version < 2:
+ # the key is different in early versions
+ # In version < 2, ModulatedDeformConvPack
+ # loads previous benchmark models.
+ if (prefix + 'conv_offset.weight' not in state_dict
+ and prefix[:-1] + '_offset.weight' in state_dict):
+ state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
+ prefix[:-1] + '_offset.weight')
+ if (prefix + 'conv_offset.bias' not in state_dict
+ and prefix[:-1] + '_offset.bias' in state_dict):
+ state_dict[prefix +
+ 'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
+ '_offset.bias')
+
+ if version is not None and version > 1:
+ print_log(
+ f'ModulatedDeformConvPack {prefix.rstrip(".")} is upgraded to '
+ 'version 2.',
+ logger='root')
+
+ super()._load_from_state_dict(state_dict, prefix, local_metadata,
+ strict, missing_keys, unexpected_keys,
+ error_msgs)
diff --git a/src/custom_mmpkg/custom_mmcv/ops/multi_scale_deform_attn.py b/src/custom_mmpkg/custom_mmcv/ops/multi_scale_deform_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..8696322b086872322185b6be4daf15f94d5981a0
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/multi_scale_deform_attn.py
@@ -0,0 +1,358 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd.function import Function, once_differentiable
+
+from custom_mmpkg.custom_mmcv import deprecated_api_warning
+from custom_mmpkg.custom_mmcv.cnn import constant_init, xavier_init
+from custom_mmpkg.custom_mmcv.cnn.bricks.registry import ATTENTION
+from custom_mmpkg.custom_mmcv.runner import BaseModule
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
+
+
+class MultiScaleDeformableAttnFunction(Function):
+
+ @staticmethod
+ def forward(ctx, value, value_spatial_shapes, value_level_start_index,
+ sampling_locations, attention_weights, im2col_step):
+ """GPU version of multi-scale deformable attention.
+
+ Args:
+ value (Tensor): The value has shape
+ (bs, num_keys, mum_heads, embed_dims//num_heads)
+ value_spatial_shapes (Tensor): Spatial shape of
+ each feature map, has shape (num_levels, 2),
+ last dimension 2 represent (h, w)
+ sampling_locations (Tensor): The location of sampling points,
+ has shape
+ (bs ,num_queries, num_heads, num_levels, num_points, 2),
+ the last dimension 2 represent (x, y).
+ attention_weights (Tensor): The weight of sampling points used
+ when calculate the attention, has shape
+ (bs ,num_queries, num_heads, num_levels, num_points),
+ im2col_step (Tensor): The step used in image to column.
+
+ Returns:
+ Tensor: has shape (bs, num_queries, embed_dims)
+ """
+
+ ctx.im2col_step = im2col_step
+ output = ext_module.ms_deform_attn_forward(
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ im2col_step=ctx.im2col_step)
+ ctx.save_for_backward(value, value_spatial_shapes,
+ value_level_start_index, sampling_locations,
+ attention_weights)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ """GPU version of backward function.
+
+ Args:
+ grad_output (Tensor): Gradient
+ of output tensor of forward.
+
+ Returns:
+ Tuple[Tensor]: Gradient
+ of input tensors in forward.
+ """
+ value, value_spatial_shapes, value_level_start_index,\
+ sampling_locations, attention_weights = ctx.saved_tensors
+ grad_value = torch.zeros_like(value)
+ grad_sampling_loc = torch.zeros_like(sampling_locations)
+ grad_attn_weight = torch.zeros_like(attention_weights)
+
+ ext_module.ms_deform_attn_backward(
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ grad_output.contiguous(),
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight,
+ im2col_step=ctx.im2col_step)
+
+ return grad_value, None, None, \
+ grad_sampling_loc, grad_attn_weight, None
+
+
+def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes,
+ sampling_locations, attention_weights):
+ """CPU version of multi-scale deformable attention.
+
+ Args:
+ value (Tensor): The value has shape
+ (bs, num_keys, mum_heads, embed_dims//num_heads)
+ value_spatial_shapes (Tensor): Spatial shape of
+ each feature map, has shape (num_levels, 2),
+ last dimension 2 represent (h, w)
+ sampling_locations (Tensor): The location of sampling points,
+ has shape
+ (bs ,num_queries, num_heads, num_levels, num_points, 2),
+ the last dimension 2 represent (x, y).
+ attention_weights (Tensor): The weight of sampling points used
+ when calculate the attention, has shape
+ (bs ,num_queries, num_heads, num_levels, num_points),
+
+ Returns:
+ Tensor: has shape (bs, num_queries, embed_dims)
+ """
+
+ bs, _, num_heads, embed_dims = value.shape
+ _, num_queries, num_heads, num_levels, num_points, _ =\
+ sampling_locations.shape
+ value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes],
+ dim=1)
+ sampling_grids = 2 * sampling_locations - 1
+ sampling_value_list = []
+ for level, (H_, W_) in enumerate(value_spatial_shapes):
+ # bs, H_*W_, num_heads, embed_dims ->
+ # bs, H_*W_, num_heads*embed_dims ->
+ # bs, num_heads*embed_dims, H_*W_ ->
+ # bs*num_heads, embed_dims, H_, W_
+ value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(
+ bs * num_heads, embed_dims, H_, W_)
+ # bs, num_queries, num_heads, num_points, 2 ->
+ # bs, num_heads, num_queries, num_points, 2 ->
+ # bs*num_heads, num_queries, num_points, 2
+ sampling_grid_l_ = sampling_grids[:, :, :,
+ level].transpose(1, 2).flatten(0, 1)
+ # bs*num_heads, embed_dims, num_queries, num_points
+ sampling_value_l_ = F.grid_sample(
+ value_l_,
+ sampling_grid_l_,
+ mode='bilinear',
+ padding_mode='zeros',
+ align_corners=False)
+ sampling_value_list.append(sampling_value_l_)
+ # (bs, num_queries, num_heads, num_levels, num_points) ->
+ # (bs, num_heads, num_queries, num_levels, num_points) ->
+ # (bs, num_heads, 1, num_queries, num_levels*num_points)
+ attention_weights = attention_weights.transpose(1, 2).reshape(
+ bs * num_heads, 1, num_queries, num_levels * num_points)
+ output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) *
+ attention_weights).sum(-1).view(bs, num_heads * embed_dims,
+ num_queries)
+ return output.transpose(1, 2).contiguous()
+
+
+@ATTENTION.register_module()
+class MultiScaleDeformableAttention(BaseModule):
+ """An attention module used in Deformable-Detr.
+
+ `Deformable DETR: Deformable Transformers for End-to-End Object Detection.
+ `_.
+
+ Args:
+ embed_dims (int): The embedding dimension of Attention.
+ Default: 256.
+ num_heads (int): Parallel attention heads. Default: 64.
+ num_levels (int): The number of feature map used in
+ Attention. Default: 4.
+ num_points (int): The number of sampling points for
+ each query in each head. Default: 4.
+ im2col_step (int): The step used in image_to_column.
+ Default: 64.
+ dropout (float): A Dropout layer on `inp_identity`.
+ Default: 0.1.
+ batch_first (bool): Key, Query and Value are shape of
+ (batch, n, embed_dim)
+ or (n, batch, embed_dim). Default to False.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: None.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ embed_dims=256,
+ num_heads=8,
+ num_levels=4,
+ num_points=4,
+ im2col_step=64,
+ dropout=0.1,
+ batch_first=False,
+ norm_cfg=None,
+ init_cfg=None):
+ super().__init__(init_cfg)
+ if embed_dims % num_heads != 0:
+ raise ValueError(f'embed_dims must be divisible by num_heads, '
+ f'but got {embed_dims} and {num_heads}')
+ dim_per_head = embed_dims // num_heads
+ self.norm_cfg = norm_cfg
+ self.dropout = nn.Dropout(dropout)
+ self.batch_first = batch_first
+
+ # you'd better set dim_per_head to a power of 2
+ # which is more efficient in the CUDA implementation
+ def _is_power_of_2(n):
+ if (not isinstance(n, int)) or (n < 0):
+ raise ValueError(
+ 'invalid input for _is_power_of_2: {} (type: {})'.format(
+ n, type(n)))
+ return (n & (n - 1) == 0) and n != 0
+
+ if not _is_power_of_2(dim_per_head):
+ warnings.warn(
+ "You'd better set embed_dims in "
+ 'MultiScaleDeformAttention to make '
+ 'the dimension of each attention head a power of 2 '
+ 'which is more efficient in our CUDA implementation.')
+
+ self.im2col_step = im2col_step
+ self.embed_dims = embed_dims
+ self.num_levels = num_levels
+ self.num_heads = num_heads
+ self.num_points = num_points
+ self.sampling_offsets = nn.Linear(
+ embed_dims, num_heads * num_levels * num_points * 2)
+ self.attention_weights = nn.Linear(embed_dims,
+ num_heads * num_levels * num_points)
+ self.value_proj = nn.Linear(embed_dims, embed_dims)
+ self.output_proj = nn.Linear(embed_dims, embed_dims)
+ self.init_weights()
+
+ def init_weights(self):
+ """Default initialization for Parameters of Module."""
+ constant_init(self.sampling_offsets, 0.)
+ thetas = torch.arange(
+ self.num_heads,
+ dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+ grid_init = (grid_init /
+ grid_init.abs().max(-1, keepdim=True)[0]).view(
+ self.num_heads, 1, 1,
+ 2).repeat(1, self.num_levels, self.num_points, 1)
+ for i in range(self.num_points):
+ grid_init[:, :, i, :] *= i + 1
+
+ self.sampling_offsets.bias.data = grid_init.view(-1)
+ constant_init(self.attention_weights, val=0., bias=0.)
+ xavier_init(self.value_proj, distribution='uniform', bias=0.)
+ xavier_init(self.output_proj, distribution='uniform', bias=0.)
+ self._is_init = True
+
+ @deprecated_api_warning({'residual': 'identity'},
+ cls_name='MultiScaleDeformableAttention')
+ def forward(self,
+ query,
+ key=None,
+ value=None,
+ identity=None,
+ query_pos=None,
+ key_padding_mask=None,
+ reference_points=None,
+ spatial_shapes=None,
+ level_start_index=None,
+ **kwargs):
+ """Forward Function of MultiScaleDeformAttention.
+
+ Args:
+ query (Tensor): Query of Transformer with shape
+ (num_query, bs, embed_dims).
+ key (Tensor): The key tensor with shape
+ `(num_key, bs, embed_dims)`.
+ value (Tensor): The value tensor with shape
+ `(num_key, bs, embed_dims)`.
+ identity (Tensor): The tensor used for addition, with the
+ same shape as `query`. Default None. If None,
+ `query` will be used.
+ query_pos (Tensor): The positional encoding for `query`.
+ Default: None.
+ key_pos (Tensor): The positional encoding for `key`. Default
+ None.
+ reference_points (Tensor): The normalized reference
+ points with shape (bs, num_query, num_levels, 2),
+ all elements is range in [0, 1], top-left (0,0),
+ bottom-right (1, 1), including padding area.
+ or (N, Length_{query}, num_levels, 4), add
+ additional two dimensions is (w, h) to
+ form reference boxes.
+ key_padding_mask (Tensor): ByteTensor for `query`, with
+ shape [bs, num_key].
+ spatial_shapes (Tensor): Spatial shape of features in
+ different levels. With shape (num_levels, 2),
+ last dimension represents (h, w).
+ level_start_index (Tensor): The start index of each level.
+ A tensor has shape ``(num_levels, )`` and can be represented
+ as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
+
+ Returns:
+ Tensor: forwarded results with shape [num_query, bs, embed_dims].
+ """
+
+ if value is None:
+ value = query
+
+ if identity is None:
+ identity = query
+ if query_pos is not None:
+ query = query + query_pos
+ if not self.batch_first:
+ # change to (bs, num_query ,embed_dims)
+ query = query.permute(1, 0, 2)
+ value = value.permute(1, 0, 2)
+
+ bs, num_query, _ = query.shape
+ bs, num_value, _ = value.shape
+ assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
+
+ value = self.value_proj(value)
+ if key_padding_mask is not None:
+ value = value.masked_fill(key_padding_mask[..., None], 0.0)
+ value = value.view(bs, num_value, self.num_heads, -1)
+ sampling_offsets = self.sampling_offsets(query).view(
+ bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
+ attention_weights = self.attention_weights(query).view(
+ bs, num_query, self.num_heads, self.num_levels * self.num_points)
+ attention_weights = attention_weights.softmax(-1)
+
+ attention_weights = attention_weights.view(bs, num_query,
+ self.num_heads,
+ self.num_levels,
+ self.num_points)
+ if reference_points.shape[-1] == 2:
+ offset_normalizer = torch.stack(
+ [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
+ sampling_locations = reference_points[:, :, None, :, None, :] \
+ + sampling_offsets \
+ / offset_normalizer[None, None, None, :, None, :]
+ elif reference_points.shape[-1] == 4:
+ sampling_locations = reference_points[:, :, None, :, None, :2] \
+ + sampling_offsets / self.num_points \
+ * reference_points[:, :, None, :, None, 2:] \
+ * 0.5
+ else:
+ raise ValueError(
+ f'Last dim of reference_points must be'
+ f' 2 or 4, but get {reference_points.shape[-1]} instead.')
+ if torch.cuda.is_available() and value.is_cuda:
+ output = MultiScaleDeformableAttnFunction.apply(
+ value, spatial_shapes, level_start_index, sampling_locations,
+ attention_weights, self.im2col_step)
+ else:
+ output = multi_scale_deformable_attn_pytorch(
+ value, spatial_shapes, sampling_locations, attention_weights)
+
+ output = self.output_proj(output)
+
+ if not self.batch_first:
+ # (num_query, bs ,embed_dims)
+ output = output.permute(1, 0, 2)
+
+ return self.dropout(output) + identity
diff --git a/src/custom_mmpkg/custom_mmcv/ops/nms.py b/src/custom_mmpkg/custom_mmcv/ops/nms.py
new file mode 100644
index 0000000000000000000000000000000000000000..080c0cf0f2ddef9c4d502b8011c85ed10eff94af
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/nms.py
@@ -0,0 +1,417 @@
+import os
+
+import numpy as np
+import torch
+
+from custom_mmpkg.custom_mmcv.utils import deprecated_api_warning
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['nms', 'softnms', 'nms_match', 'nms_rotated'])
+
+
+# This function is modified from: https://github.com/pytorch/vision/
+class NMSop(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, bboxes, scores, iou_threshold, offset, score_threshold,
+ max_num):
+ is_filtering_by_score = score_threshold > 0
+ if is_filtering_by_score:
+ valid_mask = scores > score_threshold
+ bboxes, scores = bboxes[valid_mask], scores[valid_mask]
+ valid_inds = torch.nonzero(
+ valid_mask, as_tuple=False).squeeze(dim=1)
+
+ inds = ext_module.nms(
+ bboxes, scores, iou_threshold=float(iou_threshold), offset=offset)
+
+ if max_num > 0:
+ inds = inds[:max_num]
+ if is_filtering_by_score:
+ inds = valid_inds[inds]
+ return inds
+
+ @staticmethod
+ def symbolic(g, bboxes, scores, iou_threshold, offset, score_threshold,
+ max_num):
+ from ..onnx import is_custom_op_loaded
+ has_custom_op = is_custom_op_loaded()
+ # TensorRT nms plugin is aligned with original nms in ONNXRuntime
+ is_trt_backend = os.environ.get('ONNX_BACKEND') == 'MMCVTensorRT'
+ if has_custom_op and (not is_trt_backend):
+ return g.op(
+ 'mmcv::NonMaxSuppression',
+ bboxes,
+ scores,
+ iou_threshold_f=float(iou_threshold),
+ offset_i=int(offset))
+ else:
+ from torch.onnx.symbolic_opset9 import select, squeeze, unsqueeze
+ from ..onnx.onnx_utils.symbolic_helper import _size_helper
+
+ boxes = unsqueeze(g, bboxes, 0)
+ scores = unsqueeze(g, unsqueeze(g, scores, 0), 0)
+
+ if max_num > 0:
+ max_num = g.op(
+ 'Constant',
+ value_t=torch.tensor(max_num, dtype=torch.long))
+ else:
+ dim = g.op('Constant', value_t=torch.tensor(0))
+ max_num = _size_helper(g, bboxes, dim)
+ max_output_per_class = max_num
+ iou_threshold = g.op(
+ 'Constant',
+ value_t=torch.tensor([iou_threshold], dtype=torch.float))
+ score_threshold = g.op(
+ 'Constant',
+ value_t=torch.tensor([score_threshold], dtype=torch.float))
+ nms_out = g.op('NonMaxSuppression', boxes, scores,
+ max_output_per_class, iou_threshold,
+ score_threshold)
+ return squeeze(
+ g,
+ select(
+ g, nms_out, 1,
+ g.op(
+ 'Constant',
+ value_t=torch.tensor([2], dtype=torch.long))), 1)
+
+
+class SoftNMSop(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, boxes, scores, iou_threshold, sigma, min_score, method,
+ offset):
+ dets = boxes.new_empty((boxes.size(0), 5), device='cpu')
+ inds = ext_module.softnms(
+ boxes.cpu(),
+ scores.cpu(),
+ dets.cpu(),
+ iou_threshold=float(iou_threshold),
+ sigma=float(sigma),
+ min_score=float(min_score),
+ method=int(method),
+ offset=int(offset))
+ return dets, inds
+
+ @staticmethod
+ def symbolic(g, boxes, scores, iou_threshold, sigma, min_score, method,
+ offset):
+ from packaging import version
+ assert version.parse(torch.__version__) >= version.parse('1.7.0')
+ nms_out = g.op(
+ 'mmcv::SoftNonMaxSuppression',
+ boxes,
+ scores,
+ iou_threshold_f=float(iou_threshold),
+ sigma_f=float(sigma),
+ min_score_f=float(min_score),
+ method_i=int(method),
+ offset_i=int(offset),
+ outputs=2)
+ return nms_out
+
+
+@deprecated_api_warning({'iou_thr': 'iou_threshold'})
+def nms(boxes, scores, iou_threshold, offset=0, score_threshold=0, max_num=-1):
+ """Dispatch to either CPU or GPU NMS implementations.
+
+ The input can be either torch tensor or numpy array. GPU NMS will be used
+ if the input is gpu tensor, otherwise CPU NMS
+ will be used. The returned type will always be the same as inputs.
+
+ Arguments:
+ boxes (torch.Tensor or np.ndarray): boxes in shape (N, 4).
+ scores (torch.Tensor or np.ndarray): scores in shape (N, ).
+ iou_threshold (float): IoU threshold for NMS.
+ offset (int, 0 or 1): boxes' width or height is (x2 - x1 + offset).
+ score_threshold (float): score threshold for NMS.
+ max_num (int): maximum number of boxes after NMS.
+
+ Returns:
+ tuple: kept dets(boxes and scores) and indice, which is always the \
+ same data type as the input.
+
+ Example:
+ >>> boxes = np.array([[49.1, 32.4, 51.0, 35.9],
+ >>> [49.3, 32.9, 51.0, 35.3],
+ >>> [49.2, 31.8, 51.0, 35.4],
+ >>> [35.1, 11.5, 39.1, 15.7],
+ >>> [35.6, 11.8, 39.3, 14.2],
+ >>> [35.3, 11.5, 39.9, 14.5],
+ >>> [35.2, 11.7, 39.7, 15.7]], dtype=np.float32)
+ >>> scores = np.array([0.9, 0.9, 0.5, 0.5, 0.5, 0.4, 0.3],\
+ dtype=np.float32)
+ >>> iou_threshold = 0.6
+ >>> dets, inds = nms(boxes, scores, iou_threshold)
+ >>> assert len(inds) == len(dets) == 3
+ """
+ assert isinstance(boxes, (torch.Tensor, np.ndarray))
+ assert isinstance(scores, (torch.Tensor, np.ndarray))
+ is_numpy = False
+ if isinstance(boxes, np.ndarray):
+ is_numpy = True
+ boxes = torch.from_numpy(boxes)
+ if isinstance(scores, np.ndarray):
+ scores = torch.from_numpy(scores)
+ assert boxes.size(1) == 4
+ assert boxes.size(0) == scores.size(0)
+ assert offset in (0, 1)
+
+ if torch.__version__ == 'parrots':
+ indata_list = [boxes, scores]
+ indata_dict = {
+ 'iou_threshold': float(iou_threshold),
+ 'offset': int(offset)
+ }
+ inds = ext_module.nms(*indata_list, **indata_dict)
+ else:
+ inds = NMSop.apply(boxes, scores, iou_threshold, offset,
+ score_threshold, max_num)
+ dets = torch.cat((boxes[inds], scores[inds].reshape(-1, 1)), dim=1)
+ if is_numpy:
+ dets = dets.cpu().numpy()
+ inds = inds.cpu().numpy()
+ return dets, inds
+
+
+@deprecated_api_warning({'iou_thr': 'iou_threshold'})
+def soft_nms(boxes,
+ scores,
+ iou_threshold=0.3,
+ sigma=0.5,
+ min_score=1e-3,
+ method='linear',
+ offset=0):
+ """Dispatch to only CPU Soft NMS implementations.
+
+ The input can be either a torch tensor or numpy array.
+ The returned type will always be the same as inputs.
+
+ Arguments:
+ boxes (torch.Tensor or np.ndarray): boxes in shape (N, 4).
+ scores (torch.Tensor or np.ndarray): scores in shape (N, ).
+ iou_threshold (float): IoU threshold for NMS.
+ sigma (float): hyperparameter for gaussian method
+ min_score (float): score filter threshold
+ method (str): either 'linear' or 'gaussian'
+ offset (int, 0 or 1): boxes' width or height is (x2 - x1 + offset).
+
+ Returns:
+ tuple: kept dets(boxes and scores) and indice, which is always the \
+ same data type as the input.
+
+ Example:
+ >>> boxes = np.array([[4., 3., 5., 3.],
+ >>> [4., 3., 5., 4.],
+ >>> [3., 1., 3., 1.],
+ >>> [3., 1., 3., 1.],
+ >>> [3., 1., 3., 1.],
+ >>> [3., 1., 3., 1.]], dtype=np.float32)
+ >>> scores = np.array([0.9, 0.9, 0.5, 0.5, 0.4, 0.0], dtype=np.float32)
+ >>> iou_threshold = 0.6
+ >>> dets, inds = soft_nms(boxes, scores, iou_threshold, sigma=0.5)
+ >>> assert len(inds) == len(dets) == 5
+ """
+
+ assert isinstance(boxes, (torch.Tensor, np.ndarray))
+ assert isinstance(scores, (torch.Tensor, np.ndarray))
+ is_numpy = False
+ if isinstance(boxes, np.ndarray):
+ is_numpy = True
+ boxes = torch.from_numpy(boxes)
+ if isinstance(scores, np.ndarray):
+ scores = torch.from_numpy(scores)
+ assert boxes.size(1) == 4
+ assert boxes.size(0) == scores.size(0)
+ assert offset in (0, 1)
+ method_dict = {'naive': 0, 'linear': 1, 'gaussian': 2}
+ assert method in method_dict.keys()
+
+ if torch.__version__ == 'parrots':
+ dets = boxes.new_empty((boxes.size(0), 5), device='cpu')
+ indata_list = [boxes.cpu(), scores.cpu(), dets.cpu()]
+ indata_dict = {
+ 'iou_threshold': float(iou_threshold),
+ 'sigma': float(sigma),
+ 'min_score': min_score,
+ 'method': method_dict[method],
+ 'offset': int(offset)
+ }
+ inds = ext_module.softnms(*indata_list, **indata_dict)
+ else:
+ dets, inds = SoftNMSop.apply(boxes.cpu(), scores.cpu(),
+ float(iou_threshold), float(sigma),
+ float(min_score), method_dict[method],
+ int(offset))
+
+ dets = dets[:inds.size(0)]
+
+ if is_numpy:
+ dets = dets.cpu().numpy()
+ inds = inds.cpu().numpy()
+ return dets, inds
+ else:
+ return dets.to(device=boxes.device), inds.to(device=boxes.device)
+
+
+def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
+ """Performs non-maximum suppression in a batched fashion.
+
+ Modified from https://github.com/pytorch/vision/blob
+ /505cd6957711af790211896d32b40291bea1bc21/torchvision/ops/boxes.py#L39.
+ In order to perform NMS independently per class, we add an offset to all
+ the boxes. The offset is dependent only on the class idx, and is large
+ enough so that boxes from different classes do not overlap.
+
+ Arguments:
+ boxes (torch.Tensor): boxes in shape (N, 4).
+ scores (torch.Tensor): scores in shape (N, ).
+ idxs (torch.Tensor): each index value correspond to a bbox cluster,
+ and NMS will not be applied between elements of different idxs,
+ shape (N, ).
+ nms_cfg (dict): specify nms type and other parameters like iou_thr.
+ Possible keys includes the following.
+
+ - iou_thr (float): IoU threshold used for NMS.
+ - split_thr (float): threshold number of boxes. In some cases the
+ number of boxes is large (e.g., 200k). To avoid OOM during
+ training, the users could set `split_thr` to a small value.
+ If the number of boxes is greater than the threshold, it will
+ perform NMS on each group of boxes separately and sequentially.
+ Defaults to 10000.
+ class_agnostic (bool): if true, nms is class agnostic,
+ i.e. IoU thresholding happens over all boxes,
+ regardless of the predicted class.
+
+ Returns:
+ tuple: kept dets and indice.
+ """
+ nms_cfg_ = nms_cfg.copy()
+ class_agnostic = nms_cfg_.pop('class_agnostic', class_agnostic)
+ if class_agnostic:
+ boxes_for_nms = boxes
+ else:
+ max_coordinate = boxes.max()
+ offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
+ boxes_for_nms = boxes + offsets[:, None]
+
+ nms_type = nms_cfg_.pop('type', 'nms')
+ nms_op = eval(nms_type)
+
+ split_thr = nms_cfg_.pop('split_thr', 10000)
+ # Won't split to multiple nms nodes when exporting to onnx
+ if boxes_for_nms.shape[0] < split_thr or torch.onnx.is_in_onnx_export():
+ dets, keep = nms_op(boxes_for_nms, scores, **nms_cfg_)
+ boxes = boxes[keep]
+ # -1 indexing works abnormal in TensorRT
+ # This assumes `dets` has 5 dimensions where
+ # the last dimension is score.
+ # TODO: more elegant way to handle the dimension issue.
+ # Some type of nms would reweight the score, such as SoftNMS
+ scores = dets[:, 4]
+ else:
+ max_num = nms_cfg_.pop('max_num', -1)
+ total_mask = scores.new_zeros(scores.size(), dtype=torch.bool)
+ # Some type of nms would reweight the score, such as SoftNMS
+ scores_after_nms = scores.new_zeros(scores.size())
+ for id in torch.unique(idxs):
+ mask = (idxs == id).nonzero(as_tuple=False).view(-1)
+ dets, keep = nms_op(boxes_for_nms[mask], scores[mask], **nms_cfg_)
+ total_mask[mask[keep]] = True
+ scores_after_nms[mask[keep]] = dets[:, -1]
+ keep = total_mask.nonzero(as_tuple=False).view(-1)
+
+ scores, inds = scores_after_nms[keep].sort(descending=True)
+ keep = keep[inds]
+ boxes = boxes[keep]
+
+ if max_num > 0:
+ keep = keep[:max_num]
+ boxes = boxes[:max_num]
+ scores = scores[:max_num]
+
+ return torch.cat([boxes, scores[:, None]], -1), keep
+
+
+def nms_match(dets, iou_threshold):
+ """Matched dets into different groups by NMS.
+
+ NMS match is Similar to NMS but when a bbox is suppressed, nms match will
+ record the indice of suppressed bbox and form a group with the indice of
+ kept bbox. In each group, indice is sorted as score order.
+
+ Arguments:
+ dets (torch.Tensor | np.ndarray): Det boxes with scores, shape (N, 5).
+ iou_thr (float): IoU thresh for NMS.
+
+ Returns:
+ List[torch.Tensor | np.ndarray]: The outer list corresponds different
+ matched group, the inner Tensor corresponds the indices for a group
+ in score order.
+ """
+ if dets.shape[0] == 0:
+ matched = []
+ else:
+ assert dets.shape[-1] == 5, 'inputs dets.shape should be (N, 5), ' \
+ f'but get {dets.shape}'
+ if isinstance(dets, torch.Tensor):
+ dets_t = dets.detach().cpu()
+ else:
+ dets_t = torch.from_numpy(dets)
+ indata_list = [dets_t]
+ indata_dict = {'iou_threshold': float(iou_threshold)}
+ matched = ext_module.nms_match(*indata_list, **indata_dict)
+ if torch.__version__ == 'parrots':
+ matched = matched.tolist()
+
+ if isinstance(dets, torch.Tensor):
+ return [dets.new_tensor(m, dtype=torch.long) for m in matched]
+ else:
+ return [np.array(m, dtype=np.int) for m in matched]
+
+
+def nms_rotated(dets, scores, iou_threshold, labels=None):
+ """Performs non-maximum suppression (NMS) on the rotated boxes according to
+ their intersection-over-union (IoU).
+
+ Rotated NMS iteratively removes lower scoring rotated boxes which have an
+ IoU greater than iou_threshold with another (higher scoring) rotated box.
+
+ Args:
+ boxes (Tensor): Rotated boxes in shape (N, 5). They are expected to \
+ be in (x_ctr, y_ctr, width, height, angle_radian) format.
+ scores (Tensor): scores in shape (N, ).
+ iou_threshold (float): IoU thresh for NMS.
+ labels (Tensor): boxes' label in shape (N,).
+
+ Returns:
+ tuple: kept dets(boxes and scores) and indice, which is always the \
+ same data type as the input.
+ """
+ if dets.shape[0] == 0:
+ return dets, None
+ multi_label = labels is not None
+ if multi_label:
+ dets_wl = torch.cat((dets, labels.unsqueeze(1)), 1)
+ else:
+ dets_wl = dets
+ _, order = scores.sort(0, descending=True)
+ dets_sorted = dets_wl.index_select(0, order)
+
+ if torch.__version__ == 'parrots':
+ keep_inds = ext_module.nms_rotated(
+ dets_wl,
+ scores,
+ order,
+ dets_sorted,
+ iou_threshold=iou_threshold,
+ multi_label=multi_label)
+ else:
+ keep_inds = ext_module.nms_rotated(dets_wl, scores, order, dets_sorted,
+ iou_threshold, multi_label)
+ dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)),
+ dim=1)
+ return dets, keep_inds
diff --git a/src/custom_mmpkg/custom_mmcv/ops/pixel_group.py b/src/custom_mmpkg/custom_mmcv/ops/pixel_group.py
new file mode 100644
index 0000000000000000000000000000000000000000..2143c75f835a467c802fc3c37ecd3ac0f85bcda4
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/pixel_group.py
@@ -0,0 +1,75 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['pixel_group'])
+
+
+def pixel_group(score, mask, embedding, kernel_label, kernel_contour,
+ kernel_region_num, distance_threshold):
+ """Group pixels into text instances, which is widely used text detection
+ methods.
+
+ Arguments:
+ score (np.array or Tensor): The foreground score with size hxw.
+ mask (np.array or Tensor): The foreground mask with size hxw.
+ embedding (np.array or Tensor): The embedding with size hxwxc to
+ distinguish instances.
+ kernel_label (np.array or Tensor): The instance kernel index with
+ size hxw.
+ kernel_contour (np.array or Tensor): The kernel contour with size hxw.
+ kernel_region_num (int): The instance kernel region number.
+ distance_threshold (float): The embedding distance threshold between
+ kernel and pixel in one instance.
+
+ Returns:
+ pixel_assignment (List[List[float]]): The instance coordinate list.
+ Each element consists of averaged confidence, pixel number, and
+ coordinates (x_i, y_i for all pixels) in order.
+ """
+ assert isinstance(score, (torch.Tensor, np.ndarray))
+ assert isinstance(mask, (torch.Tensor, np.ndarray))
+ assert isinstance(embedding, (torch.Tensor, np.ndarray))
+ assert isinstance(kernel_label, (torch.Tensor, np.ndarray))
+ assert isinstance(kernel_contour, (torch.Tensor, np.ndarray))
+ assert isinstance(kernel_region_num, int)
+ assert isinstance(distance_threshold, float)
+
+ if isinstance(score, np.ndarray):
+ score = torch.from_numpy(score)
+ if isinstance(mask, np.ndarray):
+ mask = torch.from_numpy(mask)
+ if isinstance(embedding, np.ndarray):
+ embedding = torch.from_numpy(embedding)
+ if isinstance(kernel_label, np.ndarray):
+ kernel_label = torch.from_numpy(kernel_label)
+ if isinstance(kernel_contour, np.ndarray):
+ kernel_contour = torch.from_numpy(kernel_contour)
+
+ if torch.__version__ == 'parrots':
+ label = ext_module.pixel_group(
+ score,
+ mask,
+ embedding,
+ kernel_label,
+ kernel_contour,
+ kernel_region_num=kernel_region_num,
+ distance_threshold=distance_threshold)
+ label = label.tolist()
+ label = label[0]
+ list_index = kernel_region_num
+ pixel_assignment = []
+ for x in range(kernel_region_num):
+ pixel_assignment.append(
+ np.array(
+ label[list_index:list_index + int(label[x])],
+ dtype=np.float))
+ list_index = list_index + int(label[x])
+ else:
+ pixel_assignment = ext_module.pixel_group(score, mask, embedding,
+ kernel_label, kernel_contour,
+ kernel_region_num,
+ distance_threshold)
+ return pixel_assignment
diff --git a/src/custom_mmpkg/custom_mmcv/ops/point_sample.py b/src/custom_mmpkg/custom_mmcv/ops/point_sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f09ce3ce366b9f5050f04a5f611a338484b30e7
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/point_sample.py
@@ -0,0 +1,336 @@
+# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend # noqa
+
+from os import path as osp
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.modules.utils import _pair
+from torch.onnx.operators import shape_as_tensor
+
+
+def bilinear_grid_sample(im, grid, align_corners=False):
+ """Given an input and a flow-field grid, computes the output using input
+ values and pixel locations from grid. Supported only bilinear interpolation
+ method to sample the input pixels.
+
+ Args:
+ im (torch.Tensor): Input feature map, shape (N, C, H, W)
+ grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
+ align_corners {bool}: If set to True, the extrema (-1 and 1) are
+ considered as referring to the center points of the input’s
+ corner pixels. If set to False, they are instead considered as
+ referring to the corner points of the input’s corner pixels,
+ making the sampling more resolution agnostic.
+ Returns:
+ torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
+ """
+ n, c, h, w = im.shape
+ gn, gh, gw, _ = grid.shape
+ assert n == gn
+
+ x = grid[:, :, :, 0]
+ y = grid[:, :, :, 1]
+
+ if align_corners:
+ x = ((x + 1) / 2) * (w - 1)
+ y = ((y + 1) / 2) * (h - 1)
+ else:
+ x = ((x + 1) * w - 1) / 2
+ y = ((y + 1) * h - 1) / 2
+
+ x = x.view(n, -1)
+ y = y.view(n, -1)
+
+ x0 = torch.floor(x).long()
+ y0 = torch.floor(y).long()
+ x1 = x0 + 1
+ y1 = y0 + 1
+
+ wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
+ wb = ((x1 - x) * (y - y0)).unsqueeze(1)
+ wc = ((x - x0) * (y1 - y)).unsqueeze(1)
+ wd = ((x - x0) * (y - y0)).unsqueeze(1)
+
+ # Apply default for grid_sample function zero padding
+ im_padded = F.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
+ padded_h = h + 2
+ padded_w = w + 2
+ # save points positions after padding
+ x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1
+
+ # Clip coordinates to padded image size
+ x0 = torch.where(x0 < 0, torch.tensor(0), x0)
+ x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1), x0)
+ x1 = torch.where(x1 < 0, torch.tensor(0), x1)
+ x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1), x1)
+ y0 = torch.where(y0 < 0, torch.tensor(0), y0)
+ y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1), y0)
+ y1 = torch.where(y1 < 0, torch.tensor(0), y1)
+ y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1), y1)
+
+ im_padded = im_padded.view(n, c, -1)
+
+ x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
+ x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
+ x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
+ x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
+
+ Ia = torch.gather(im_padded, 2, x0_y0)
+ Ib = torch.gather(im_padded, 2, x0_y1)
+ Ic = torch.gather(im_padded, 2, x1_y0)
+ Id = torch.gather(im_padded, 2, x1_y1)
+
+ return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)
+
+
+def is_in_onnx_export_without_custom_ops():
+ from custom_mmpkg.custom_mmcv.ops import get_onnxruntime_op_path
+ ort_custom_op_path = get_onnxruntime_op_path()
+ return torch.onnx.is_in_onnx_export(
+ ) and not osp.exists(ort_custom_op_path)
+
+
+def normalize(grid):
+ """Normalize input grid from [-1, 1] to [0, 1]
+ Args:
+ grid (Tensor): The grid to be normalize, range [-1, 1].
+ Returns:
+ Tensor: Normalized grid, range [0, 1].
+ """
+
+ return (grid + 1.0) / 2.0
+
+
+def denormalize(grid):
+ """Denormalize input grid from range [0, 1] to [-1, 1]
+ Args:
+ grid (Tensor): The grid to be denormalize, range [0, 1].
+ Returns:
+ Tensor: Denormalized grid, range [-1, 1].
+ """
+
+ return grid * 2.0 - 1.0
+
+
+def generate_grid(num_grid, size, device):
+ """Generate regular square grid of points in [0, 1] x [0, 1] coordinate
+ space.
+
+ Args:
+ num_grid (int): The number of grids to sample, one for each region.
+ size (tuple(int, int)): The side size of the regular grid.
+ device (torch.device): Desired device of returned tensor.
+
+ Returns:
+ (torch.Tensor): A tensor of shape (num_grid, size[0]*size[1], 2) that
+ contains coordinates for the regular grids.
+ """
+
+ affine_trans = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]], device=device)
+ grid = F.affine_grid(
+ affine_trans, torch.Size((1, 1, *size)), align_corners=False)
+ grid = normalize(grid)
+ return grid.view(1, -1, 2).expand(num_grid, -1, -1)
+
+
+def rel_roi_point_to_abs_img_point(rois, rel_roi_points):
+ """Convert roi based relative point coordinates to image based absolute
+ point coordinates.
+
+ Args:
+ rois (Tensor): RoIs or BBoxes, shape (N, 4) or (N, 5)
+ rel_roi_points (Tensor): Point coordinates inside RoI, relative to
+ RoI, location, range (0, 1), shape (N, P, 2)
+ Returns:
+ Tensor: Image based absolute point coordinates, shape (N, P, 2)
+ """
+
+ with torch.no_grad():
+ assert rel_roi_points.size(0) == rois.size(0)
+ assert rois.dim() == 2
+ assert rel_roi_points.dim() == 3
+ assert rel_roi_points.size(2) == 2
+ # remove batch idx
+ if rois.size(1) == 5:
+ rois = rois[:, 1:]
+ abs_img_points = rel_roi_points.clone()
+ # To avoid an error during exporting to onnx use independent
+ # variables instead inplace computation
+ xs = abs_img_points[:, :, 0] * (rois[:, None, 2] - rois[:, None, 0])
+ ys = abs_img_points[:, :, 1] * (rois[:, None, 3] - rois[:, None, 1])
+ xs += rois[:, None, 0]
+ ys += rois[:, None, 1]
+ abs_img_points = torch.stack([xs, ys], dim=2)
+ return abs_img_points
+
+
+def get_shape_from_feature_map(x):
+ """Get spatial resolution of input feature map considering exporting to
+ onnx mode.
+
+ Args:
+ x (torch.Tensor): Input tensor, shape (N, C, H, W)
+ Returns:
+ torch.Tensor: Spatial resolution (width, height), shape (1, 1, 2)
+ """
+ if torch.onnx.is_in_onnx_export():
+ img_shape = shape_as_tensor(x)[2:].flip(0).view(1, 1, 2).to(
+ x.device).float()
+ else:
+ img_shape = torch.tensor(x.shape[2:]).flip(0).view(1, 1, 2).to(
+ x.device).float()
+ return img_shape
+
+
+def abs_img_point_to_rel_img_point(abs_img_points, img, spatial_scale=1.):
+ """Convert image based absolute point coordinates to image based relative
+ coordinates for sampling.
+
+ Args:
+ abs_img_points (Tensor): Image based absolute point coordinates,
+ shape (N, P, 2)
+ img (tuple/Tensor): (height, width) of image or feature map.
+ spatial_scale (float): Scale points by this factor. Default: 1.
+
+ Returns:
+ Tensor: Image based relative point coordinates for sampling,
+ shape (N, P, 2)
+ """
+
+ assert (isinstance(img, tuple) and len(img) == 2) or \
+ (isinstance(img, torch.Tensor) and len(img.shape) == 4)
+
+ if isinstance(img, tuple):
+ h, w = img
+ scale = torch.tensor([w, h],
+ dtype=torch.float,
+ device=abs_img_points.device)
+ scale = scale.view(1, 1, 2)
+ else:
+ scale = get_shape_from_feature_map(img)
+
+ return abs_img_points / scale * spatial_scale
+
+
+def rel_roi_point_to_rel_img_point(rois,
+ rel_roi_points,
+ img,
+ spatial_scale=1.):
+ """Convert roi based relative point coordinates to image based absolute
+ point coordinates.
+
+ Args:
+ rois (Tensor): RoIs or BBoxes, shape (N, 4) or (N, 5)
+ rel_roi_points (Tensor): Point coordinates inside RoI, relative to
+ RoI, location, range (0, 1), shape (N, P, 2)
+ img (tuple/Tensor): (height, width) of image or feature map.
+ spatial_scale (float): Scale points by this factor. Default: 1.
+
+ Returns:
+ Tensor: Image based relative point coordinates for sampling,
+ shape (N, P, 2)
+ """
+
+ abs_img_point = rel_roi_point_to_abs_img_point(rois, rel_roi_points)
+ rel_img_point = abs_img_point_to_rel_img_point(abs_img_point, img,
+ spatial_scale)
+
+ return rel_img_point
+
+
+def point_sample(input, points, align_corners=False, **kwargs):
+ """A wrapper around :func:`grid_sample` to support 3D point_coords tensors
+ Unlike :func:`torch.nn.functional.grid_sample` it assumes point_coords to
+ lie inside ``[0, 1] x [0, 1]`` square.
+
+ Args:
+ input (Tensor): Feature map, shape (N, C, H, W).
+ points (Tensor): Image based absolute point coordinates (normalized),
+ range [0, 1] x [0, 1], shape (N, P, 2) or (N, Hgrid, Wgrid, 2).
+ align_corners (bool): Whether align_corners. Default: False
+
+ Returns:
+ Tensor: Features of `point` on `input`, shape (N, C, P) or
+ (N, C, Hgrid, Wgrid).
+ """
+
+ add_dim = False
+ if points.dim() == 3:
+ add_dim = True
+ points = points.unsqueeze(2)
+ if is_in_onnx_export_without_custom_ops():
+ # If custom ops for onnx runtime not compiled use python
+ # implementation of grid_sample function to make onnx graph
+ # with supported nodes
+ output = bilinear_grid_sample(
+ input, denormalize(points), align_corners=align_corners)
+ else:
+ output = F.grid_sample(
+ input, denormalize(points), align_corners=align_corners, **kwargs)
+ if add_dim:
+ output = output.squeeze(3)
+ return output
+
+
+class SimpleRoIAlign(nn.Module):
+
+ def __init__(self, output_size, spatial_scale, aligned=True):
+ """Simple RoI align in PointRend, faster than standard RoIAlign.
+
+ Args:
+ output_size (tuple[int]): h, w
+ spatial_scale (float): scale the input boxes by this number
+ aligned (bool): if False, use the legacy implementation in
+ MMDetection, align_corners=True will be used in F.grid_sample.
+ If True, align the results more perfectly.
+ """
+
+ super(SimpleRoIAlign, self).__init__()
+ self.output_size = _pair(output_size)
+ self.spatial_scale = float(spatial_scale)
+ # to be consistent with other RoI ops
+ self.use_torchvision = False
+ self.aligned = aligned
+
+ def forward(self, features, rois):
+ num_imgs = features.size(0)
+ num_rois = rois.size(0)
+ rel_roi_points = generate_grid(
+ num_rois, self.output_size, device=rois.device)
+
+ if torch.onnx.is_in_onnx_export():
+ rel_img_points = rel_roi_point_to_rel_img_point(
+ rois, rel_roi_points, features, self.spatial_scale)
+ rel_img_points = rel_img_points.reshape(num_imgs, -1,
+ *rel_img_points.shape[1:])
+ point_feats = point_sample(
+ features, rel_img_points, align_corners=not self.aligned)
+ point_feats = point_feats.transpose(1, 2)
+ else:
+ point_feats = []
+ for batch_ind in range(num_imgs):
+ # unravel batch dim
+ feat = features[batch_ind].unsqueeze(0)
+ inds = (rois[:, 0].long() == batch_ind)
+ if inds.any():
+ rel_img_points = rel_roi_point_to_rel_img_point(
+ rois[inds], rel_roi_points[inds], feat,
+ self.spatial_scale).unsqueeze(0)
+ point_feat = point_sample(
+ feat, rel_img_points, align_corners=not self.aligned)
+ point_feat = point_feat.squeeze(0).transpose(0, 1)
+ point_feats.append(point_feat)
+
+ point_feats = torch.cat(point_feats, dim=0)
+
+ channels = features.size(1)
+ roi_feats = point_feats.reshape(num_rois, channels, *self.output_size)
+
+ return roi_feats
+
+ def __repr__(self):
+ format_str = self.__class__.__name__
+ format_str += '(output_size={}, spatial_scale={}'.format(
+ self.output_size, self.spatial_scale)
+ return format_str
diff --git a/src/custom_mmpkg/custom_mmcv/ops/points_in_boxes.py b/src/custom_mmpkg/custom_mmcv/ops/points_in_boxes.py
new file mode 100644
index 0000000000000000000000000000000000000000..4003173a53052161dbcd687a2fa1d755642fdab8
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/points_in_boxes.py
@@ -0,0 +1,133 @@
+import torch
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+ 'points_in_boxes_part_forward', 'points_in_boxes_cpu_forward',
+ 'points_in_boxes_all_forward'
+])
+
+
+def points_in_boxes_part(points, boxes):
+ """Find the box in which each point is (CUDA).
+
+ Args:
+ points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate
+ boxes (torch.Tensor): [B, T, 7],
+ num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz] in
+ LiDAR/DEPTH coordinate, (x, y, z) is the bottom center
+
+ Returns:
+ box_idxs_of_pts (torch.Tensor): (B, M), default background = -1
+ """
+ assert points.shape[0] == boxes.shape[0], \
+ 'Points and boxes should have the same batch size, ' \
+ f'but got {points.shape[0]} and {boxes.shape[0]}'
+ assert boxes.shape[2] == 7, \
+ 'boxes dimension should be 7, ' \
+ f'but got unexpected shape {boxes.shape[2]}'
+ assert points.shape[2] == 3, \
+ 'points dimension should be 3, ' \
+ f'but got unexpected shape {points.shape[2]}'
+ batch_size, num_points, _ = points.shape
+
+ box_idxs_of_pts = points.new_zeros((batch_size, num_points),
+ dtype=torch.int).fill_(-1)
+
+ # If manually put the tensor 'points' or 'boxes' on a device
+ # which is not the current device, some temporary variables
+ # will be created on the current device in the cuda op,
+ # and the output will be incorrect.
+ # Therefore, we force the current device to be the same
+ # as the device of the tensors if it was not.
+ # Please refer to https://github.com/open-mmlab/mmdetection3d/issues/305
+ # for the incorrect output before the fix.
+ points_device = points.get_device()
+ assert points_device == boxes.get_device(), \
+ 'Points and boxes should be put on the same device'
+ if torch.cuda.current_device() != points_device:
+ torch.cuda.set_device(points_device)
+
+ ext_module.points_in_boxes_part_forward(boxes.contiguous(),
+ points.contiguous(),
+ box_idxs_of_pts)
+
+ return box_idxs_of_pts
+
+
+def points_in_boxes_cpu(points, boxes):
+ """Find all boxes in which each point is (CPU). The CPU version of
+ :meth:`points_in_boxes_all`.
+
+ Args:
+ points (torch.Tensor): [B, M, 3], [x, y, z] in
+ LiDAR/DEPTH coordinate
+ boxes (torch.Tensor): [B, T, 7],
+ num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz],
+ (x, y, z) is the bottom center.
+
+ Returns:
+ box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0.
+ """
+ assert points.shape[0] == boxes.shape[0], \
+ 'Points and boxes should have the same batch size, ' \
+ f'but got {points.shape[0]} and {boxes.shape[0]}'
+ assert boxes.shape[2] == 7, \
+ 'boxes dimension should be 7, ' \
+ f'but got unexpected shape {boxes.shape[2]}'
+ assert points.shape[2] == 3, \
+ 'points dimension should be 3, ' \
+ f'but got unexpected shape {points.shape[2]}'
+ batch_size, num_points, _ = points.shape
+ num_boxes = boxes.shape[1]
+
+ point_indices = points.new_zeros((batch_size, num_boxes, num_points),
+ dtype=torch.int)
+ for b in range(batch_size):
+ ext_module.points_in_boxes_cpu_forward(boxes[b].float().contiguous(),
+ points[b].float().contiguous(),
+ point_indices[b])
+ point_indices = point_indices.transpose(1, 2)
+
+ return point_indices
+
+
+def points_in_boxes_all(points, boxes):
+ """Find all boxes in which each point is (CUDA).
+
+ Args:
+ points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate
+ boxes (torch.Tensor): [B, T, 7],
+ num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz],
+ (x, y, z) is the bottom center.
+
+ Returns:
+ box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0.
+ """
+ assert boxes.shape[0] == points.shape[0], \
+ 'Points and boxes should have the same batch size, ' \
+ f'but got {boxes.shape[0]} and {boxes.shape[0]}'
+ assert boxes.shape[2] == 7, \
+ 'boxes dimension should be 7, ' \
+ f'but got unexpected shape {boxes.shape[2]}'
+ assert points.shape[2] == 3, \
+ 'points dimension should be 3, ' \
+ f'but got unexpected shape {points.shape[2]}'
+ batch_size, num_points, _ = points.shape
+ num_boxes = boxes.shape[1]
+
+ box_idxs_of_pts = points.new_zeros((batch_size, num_points, num_boxes),
+ dtype=torch.int).fill_(0)
+
+ # Same reason as line 25-32
+ points_device = points.get_device()
+ assert points_device == boxes.get_device(), \
+ 'Points and boxes should be put on the same device'
+ if torch.cuda.current_device() != points_device:
+ torch.cuda.set_device(points_device)
+
+ ext_module.points_in_boxes_all_forward(boxes.contiguous(),
+ points.contiguous(),
+ box_idxs_of_pts)
+
+ return box_idxs_of_pts
diff --git a/src/custom_mmpkg/custom_mmcv/ops/points_sampler.py b/src/custom_mmpkg/custom_mmcv/ops/points_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..1df321530990289ebfe426434635351b3687dce6
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/points_sampler.py
@@ -0,0 +1,177 @@
+from typing import List
+
+import torch
+from torch import nn as nn
+
+from custom_mmpkg.custom_mmcv.runner import force_fp32
+from .furthest_point_sample import (furthest_point_sample,
+ furthest_point_sample_with_dist)
+
+
+def calc_square_dist(point_feat_a, point_feat_b, norm=True):
+ """Calculating square distance between a and b.
+
+ Args:
+ point_feat_a (Tensor): (B, N, C) Feature vector of each point.
+ point_feat_b (Tensor): (B, M, C) Feature vector of each point.
+ norm (Bool, optional): Whether to normalize the distance.
+ Default: True.
+
+ Returns:
+ Tensor: (B, N, M) Distance between each pair points.
+ """
+ num_channel = point_feat_a.shape[-1]
+ # [bs, n, 1]
+ a_square = torch.sum(point_feat_a.unsqueeze(dim=2).pow(2), dim=-1)
+ # [bs, 1, m]
+ b_square = torch.sum(point_feat_b.unsqueeze(dim=1).pow(2), dim=-1)
+
+ corr_matrix = torch.matmul(point_feat_a, point_feat_b.transpose(1, 2))
+
+ dist = a_square + b_square - 2 * corr_matrix
+ if norm:
+ dist = torch.sqrt(dist) / num_channel
+ return dist
+
+
+def get_sampler_cls(sampler_type):
+ """Get the type and mode of points sampler.
+
+ Args:
+ sampler_type (str): The type of points sampler.
+ The valid value are "D-FPS", "F-FPS", or "FS".
+
+ Returns:
+ class: Points sampler type.
+ """
+ sampler_mappings = {
+ 'D-FPS': DFPSSampler,
+ 'F-FPS': FFPSSampler,
+ 'FS': FSSampler,
+ }
+ try:
+ return sampler_mappings[sampler_type]
+ except KeyError:
+ raise KeyError(
+ f'Supported `sampler_type` are {sampler_mappings.keys()}, but got \
+ {sampler_type}')
+
+
+class PointsSampler(nn.Module):
+ """Points sampling.
+
+ Args:
+ num_point (list[int]): Number of sample points.
+ fps_mod_list (list[str], optional): Type of FPS method, valid mod
+ ['F-FPS', 'D-FPS', 'FS'], Default: ['D-FPS'].
+ F-FPS: using feature distances for FPS.
+ D-FPS: using Euclidean distances of points for FPS.
+ FS: using F-FPS and D-FPS simultaneously.
+ fps_sample_range_list (list[int], optional):
+ Range of points to apply FPS. Default: [-1].
+ """
+
+ def __init__(self,
+ num_point: List[int],
+ fps_mod_list: List[str] = ['D-FPS'],
+ fps_sample_range_list: List[int] = [-1]):
+ super().__init__()
+ # FPS would be applied to different fps_mod in the list,
+ # so the length of the num_point should be equal to
+ # fps_mod_list and fps_sample_range_list.
+ assert len(num_point) == len(fps_mod_list) == len(
+ fps_sample_range_list)
+ self.num_point = num_point
+ self.fps_sample_range_list = fps_sample_range_list
+ self.samplers = nn.ModuleList()
+ for fps_mod in fps_mod_list:
+ self.samplers.append(get_sampler_cls(fps_mod)())
+ self.fp16_enabled = False
+
+ @force_fp32()
+ def forward(self, points_xyz, features):
+ """
+ Args:
+ points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
+ features (Tensor): (B, C, N) Descriptors of the features.
+
+ Returns:
+ Tensor: (B, npoint, sample_num) Indices of sampled points.
+ """
+ indices = []
+ last_fps_end_index = 0
+
+ for fps_sample_range, sampler, npoint in zip(
+ self.fps_sample_range_list, self.samplers, self.num_point):
+ assert fps_sample_range < points_xyz.shape[1]
+
+ if fps_sample_range == -1:
+ sample_points_xyz = points_xyz[:, last_fps_end_index:]
+ if features is not None:
+ sample_features = features[:, :, last_fps_end_index:]
+ else:
+ sample_features = None
+ else:
+ sample_points_xyz = \
+ points_xyz[:, last_fps_end_index:fps_sample_range]
+ if features is not None:
+ sample_features = features[:, :, last_fps_end_index:
+ fps_sample_range]
+ else:
+ sample_features = None
+
+ fps_idx = sampler(sample_points_xyz.contiguous(), sample_features,
+ npoint)
+
+ indices.append(fps_idx + last_fps_end_index)
+ last_fps_end_index += fps_sample_range
+ indices = torch.cat(indices, dim=1)
+
+ return indices
+
+
+class DFPSSampler(nn.Module):
+ """Using Euclidean distances of points for FPS."""
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, points, features, npoint):
+ """Sampling points with D-FPS."""
+ fps_idx = furthest_point_sample(points.contiguous(), npoint)
+ return fps_idx
+
+
+class FFPSSampler(nn.Module):
+ """Using feature distances for FPS."""
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, points, features, npoint):
+ """Sampling points with F-FPS."""
+ assert features is not None, \
+ 'feature input to FFPS_Sampler should not be None'
+ features_for_fps = torch.cat([points, features.transpose(1, 2)], dim=2)
+ features_dist = calc_square_dist(
+ features_for_fps, features_for_fps, norm=False)
+ fps_idx = furthest_point_sample_with_dist(features_dist, npoint)
+ return fps_idx
+
+
+class FSSampler(nn.Module):
+ """Using F-FPS and D-FPS simultaneously."""
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, points, features, npoint):
+ """Sampling points with FS_Sampling."""
+ assert features is not None, \
+ 'feature input to FS_Sampler should not be None'
+ ffps_sampler = FFPSSampler()
+ dfps_sampler = DFPSSampler()
+ fps_idx_ffps = ffps_sampler(points, features, npoint)
+ fps_idx_dfps = dfps_sampler(points, features, npoint)
+ fps_idx = torch.cat([fps_idx_ffps, fps_idx_dfps], dim=1)
+ return fps_idx
diff --git a/src/custom_mmpkg/custom_mmcv/ops/psa_mask.py b/src/custom_mmpkg/custom_mmcv/ops/psa_mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdf14e62b50e8d4dd6856c94333c703bcc4c9ab6
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/psa_mask.py
@@ -0,0 +1,92 @@
+# Modified from https://github.com/hszhao/semseg/blob/master/lib/psa
+from torch import nn
+from torch.autograd import Function
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext',
+ ['psamask_forward', 'psamask_backward'])
+
+
+class PSAMaskFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input, psa_type, mask_size):
+ return g.op(
+ 'mmcv::MMCVPSAMask',
+ input,
+ psa_type_i=psa_type,
+ mask_size_i=mask_size)
+
+ @staticmethod
+ def forward(ctx, input, psa_type, mask_size):
+ ctx.psa_type = psa_type
+ ctx.mask_size = _pair(mask_size)
+ ctx.save_for_backward(input)
+
+ h_mask, w_mask = ctx.mask_size
+ batch_size, channels, h_feature, w_feature = input.size()
+ assert channels == h_mask * w_mask
+ output = input.new_zeros(
+ (batch_size, h_feature * w_feature, h_feature, w_feature))
+
+ ext_module.psamask_forward(
+ input,
+ output,
+ psa_type=psa_type,
+ num_=batch_size,
+ h_feature=h_feature,
+ w_feature=w_feature,
+ h_mask=h_mask,
+ w_mask=w_mask,
+ half_h_mask=(h_mask - 1) // 2,
+ half_w_mask=(w_mask - 1) // 2)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input = ctx.saved_tensors[0]
+ psa_type = ctx.psa_type
+ h_mask, w_mask = ctx.mask_size
+ batch_size, channels, h_feature, w_feature = input.size()
+ grad_input = grad_output.new_zeros(
+ (batch_size, channels, h_feature, w_feature))
+ ext_module.psamask_backward(
+ grad_output,
+ grad_input,
+ psa_type=psa_type,
+ num_=batch_size,
+ h_feature=h_feature,
+ w_feature=w_feature,
+ h_mask=h_mask,
+ w_mask=w_mask,
+ half_h_mask=(h_mask - 1) // 2,
+ half_w_mask=(w_mask - 1) // 2)
+ return grad_input, None, None, None
+
+
+psa_mask = PSAMaskFunction.apply
+
+
+class PSAMask(nn.Module):
+
+ def __init__(self, psa_type, mask_size=None):
+ super(PSAMask, self).__init__()
+ assert psa_type in ['collect', 'distribute']
+ if psa_type == 'collect':
+ psa_type_enum = 0
+ else:
+ psa_type_enum = 1
+ self.psa_type_enum = psa_type_enum
+ self.mask_size = mask_size
+ self.psa_type = psa_type
+
+ def forward(self, input):
+ return psa_mask(input, self.psa_type_enum, self.mask_size)
+
+ def __repr__(self):
+ s = self.__class__.__name__
+ s += f'(psa_type={self.psa_type}, '
+ s += f'mask_size={self.mask_size})'
+ return s
diff --git a/src/custom_mmpkg/custom_mmcv/ops/roi_align.py b/src/custom_mmpkg/custom_mmcv/ops/roi_align.py
new file mode 100644
index 0000000000000000000000000000000000000000..0755aefc66e67233ceae0f4b77948301c443e9fb
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/roi_align.py
@@ -0,0 +1,223 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from ..utils import deprecated_api_warning, ext_loader
+
+ext_module = ext_loader.load_ext('_ext',
+ ['roi_align_forward', 'roi_align_backward'])
+
+
+class RoIAlignFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input, rois, output_size, spatial_scale, sampling_ratio,
+ pool_mode, aligned):
+ from ..onnx import is_custom_op_loaded
+ has_custom_op = is_custom_op_loaded()
+ if has_custom_op:
+ return g.op(
+ 'mmcv::MMCVRoiAlign',
+ input,
+ rois,
+ output_height_i=output_size[0],
+ output_width_i=output_size[1],
+ spatial_scale_f=spatial_scale,
+ sampling_ratio_i=sampling_ratio,
+ mode_s=pool_mode,
+ aligned_i=aligned)
+ else:
+ from torch.onnx.symbolic_opset9 import sub, squeeze
+ from torch.onnx.symbolic_helper import _slice_helper
+ from torch.onnx import TensorProtoDataType
+ # batch_indices = rois[:, 0].long()
+ batch_indices = _slice_helper(
+ g, rois, axes=[1], starts=[0], ends=[1])
+ batch_indices = squeeze(g, batch_indices, 1)
+ batch_indices = g.op(
+ 'Cast', batch_indices, to_i=TensorProtoDataType.INT64)
+ # rois = rois[:, 1:]
+ rois = _slice_helper(g, rois, axes=[1], starts=[1], ends=[5])
+ if aligned:
+ # rois -= 0.5/spatial_scale
+ aligned_offset = g.op(
+ 'Constant',
+ value_t=torch.tensor([0.5 / spatial_scale],
+ dtype=torch.float32))
+ rois = sub(g, rois, aligned_offset)
+ # roi align
+ return g.op(
+ 'RoiAlign',
+ input,
+ rois,
+ batch_indices,
+ output_height_i=output_size[0],
+ output_width_i=output_size[1],
+ spatial_scale_f=spatial_scale,
+ sampling_ratio_i=max(0, sampling_ratio),
+ mode_s=pool_mode)
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ rois,
+ output_size,
+ spatial_scale=1.0,
+ sampling_ratio=0,
+ pool_mode='avg',
+ aligned=True):
+ ctx.output_size = _pair(output_size)
+ ctx.spatial_scale = spatial_scale
+ ctx.sampling_ratio = sampling_ratio
+ assert pool_mode in ('max', 'avg')
+ ctx.pool_mode = 0 if pool_mode == 'max' else 1
+ ctx.aligned = aligned
+ ctx.input_shape = input.size()
+
+ assert rois.size(1) == 5, 'RoI must be (idx, x1, y1, x2, y2)!'
+
+ output_shape = (rois.size(0), input.size(1), ctx.output_size[0],
+ ctx.output_size[1])
+ output = input.new_zeros(output_shape)
+ if ctx.pool_mode == 0:
+ argmax_y = input.new_zeros(output_shape)
+ argmax_x = input.new_zeros(output_shape)
+ else:
+ argmax_y = input.new_zeros(0)
+ argmax_x = input.new_zeros(0)
+
+ ext_module.roi_align_forward(
+ input,
+ rois,
+ output,
+ argmax_y,
+ argmax_x,
+ aligned_height=ctx.output_size[0],
+ aligned_width=ctx.output_size[1],
+ spatial_scale=ctx.spatial_scale,
+ sampling_ratio=ctx.sampling_ratio,
+ pool_mode=ctx.pool_mode,
+ aligned=ctx.aligned)
+
+ ctx.save_for_backward(rois, argmax_y, argmax_x)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ rois, argmax_y, argmax_x = ctx.saved_tensors
+ grad_input = grad_output.new_zeros(ctx.input_shape)
+ # complex head architecture may cause grad_output uncontiguous.
+ grad_output = grad_output.contiguous()
+ ext_module.roi_align_backward(
+ grad_output,
+ rois,
+ argmax_y,
+ argmax_x,
+ grad_input,
+ aligned_height=ctx.output_size[0],
+ aligned_width=ctx.output_size[1],
+ spatial_scale=ctx.spatial_scale,
+ sampling_ratio=ctx.sampling_ratio,
+ pool_mode=ctx.pool_mode,
+ aligned=ctx.aligned)
+ return grad_input, None, None, None, None, None, None
+
+
+roi_align = RoIAlignFunction.apply
+
+
+class RoIAlign(nn.Module):
+ """RoI align pooling layer.
+
+ Args:
+ output_size (tuple): h, w
+ spatial_scale (float): scale the input boxes by this number
+ sampling_ratio (int): number of inputs samples to take for each
+ output sample. 0 to take samples densely for current models.
+ pool_mode (str, 'avg' or 'max'): pooling mode in each bin.
+ aligned (bool): if False, use the legacy implementation in
+ MMDetection. If True, align the results more perfectly.
+ use_torchvision (bool): whether to use roi_align from torchvision.
+
+ Note:
+ The implementation of RoIAlign when aligned=True is modified from
+ https://github.com/facebookresearch/detectron2/
+
+ The meaning of aligned=True:
+
+ Given a continuous coordinate c, its two neighboring pixel
+ indices (in our pixel model) are computed by floor(c - 0.5) and
+ ceil(c - 0.5). For example, c=1.3 has pixel neighbors with discrete
+ indices [0] and [1] (which are sampled from the underlying signal
+ at continuous coordinates 0.5 and 1.5). But the original roi_align
+ (aligned=False) does not subtract the 0.5 when computing
+ neighboring pixel indices and therefore it uses pixels with a
+ slightly incorrect alignment (relative to our pixel model) when
+ performing bilinear interpolation.
+
+ With `aligned=True`,
+ we first appropriately scale the ROI and then shift it by -0.5
+ prior to calling roi_align. This produces the correct neighbors;
+
+ The difference does not make a difference to the model's
+ performance if ROIAlign is used together with conv layers.
+ """
+
+ @deprecated_api_warning(
+ {
+ 'out_size': 'output_size',
+ 'sample_num': 'sampling_ratio'
+ },
+ cls_name='RoIAlign')
+ def __init__(self,
+ output_size,
+ spatial_scale=1.0,
+ sampling_ratio=0,
+ pool_mode='avg',
+ aligned=True,
+ use_torchvision=False):
+ super(RoIAlign, self).__init__()
+
+ self.output_size = _pair(output_size)
+ self.spatial_scale = float(spatial_scale)
+ self.sampling_ratio = int(sampling_ratio)
+ self.pool_mode = pool_mode
+ self.aligned = aligned
+ self.use_torchvision = use_torchvision
+
+ def forward(self, input, rois):
+ """
+ Args:
+ input: NCHW images
+ rois: Bx5 boxes. First column is the index into N.\
+ The other 4 columns are xyxy.
+ """
+ if self.use_torchvision:
+ from torchvision.ops import roi_align as tv_roi_align
+ if 'aligned' in tv_roi_align.__code__.co_varnames:
+ return tv_roi_align(input, rois, self.output_size,
+ self.spatial_scale, self.sampling_ratio,
+ self.aligned)
+ else:
+ if self.aligned:
+ rois -= rois.new_tensor([0.] +
+ [0.5 / self.spatial_scale] * 4)
+ return tv_roi_align(input, rois, self.output_size,
+ self.spatial_scale, self.sampling_ratio)
+ else:
+ return roi_align(input, rois, self.output_size, self.spatial_scale,
+ self.sampling_ratio, self.pool_mode, self.aligned)
+
+ def __repr__(self):
+ s = self.__class__.__name__
+ s += f'(output_size={self.output_size}, '
+ s += f'spatial_scale={self.spatial_scale}, '
+ s += f'sampling_ratio={self.sampling_ratio}, '
+ s += f'pool_mode={self.pool_mode}, '
+ s += f'aligned={self.aligned}, '
+ s += f'use_torchvision={self.use_torchvision})'
+ return s
diff --git a/src/custom_mmpkg/custom_mmcv/ops/roi_align_rotated.py b/src/custom_mmpkg/custom_mmcv/ops/roi_align_rotated.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ce4961a3555d4da8bc3e32f1f7d5ad50036587d
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/roi_align_rotated.py
@@ -0,0 +1,177 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['roi_align_rotated_forward', 'roi_align_rotated_backward'])
+
+
+class RoIAlignRotatedFunction(Function):
+
+ @staticmethod
+ def symbolic(g, features, rois, out_size, spatial_scale, sample_num,
+ aligned, clockwise):
+ if isinstance(out_size, int):
+ out_h = out_size
+ out_w = out_size
+ elif isinstance(out_size, tuple):
+ assert len(out_size) == 2
+ assert isinstance(out_size[0], int)
+ assert isinstance(out_size[1], int)
+ out_h, out_w = out_size
+ else:
+ raise TypeError(
+ '"out_size" must be an integer or tuple of integers')
+ return g.op(
+ 'mmcv::MMCVRoIAlignRotated',
+ features,
+ rois,
+ output_height_i=out_h,
+ output_width_i=out_h,
+ spatial_scale_f=spatial_scale,
+ sampling_ratio_i=sample_num,
+ aligned_i=aligned,
+ clockwise_i=clockwise)
+
+ @staticmethod
+ def forward(ctx,
+ features,
+ rois,
+ out_size,
+ spatial_scale,
+ sample_num=0,
+ aligned=True,
+ clockwise=False):
+ if isinstance(out_size, int):
+ out_h = out_size
+ out_w = out_size
+ elif isinstance(out_size, tuple):
+ assert len(out_size) == 2
+ assert isinstance(out_size[0], int)
+ assert isinstance(out_size[1], int)
+ out_h, out_w = out_size
+ else:
+ raise TypeError(
+ '"out_size" must be an integer or tuple of integers')
+ ctx.spatial_scale = spatial_scale
+ ctx.sample_num = sample_num
+ ctx.aligned = aligned
+ ctx.clockwise = clockwise
+ ctx.save_for_backward(rois)
+ ctx.feature_size = features.size()
+
+ batch_size, num_channels, data_height, data_width = features.size()
+ num_rois = rois.size(0)
+
+ output = features.new_zeros(num_rois, num_channels, out_h, out_w)
+ ext_module.roi_align_rotated_forward(
+ features,
+ rois,
+ output,
+ pooled_height=out_h,
+ pooled_width=out_w,
+ spatial_scale=spatial_scale,
+ sample_num=sample_num,
+ aligned=aligned,
+ clockwise=clockwise)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ feature_size = ctx.feature_size
+ spatial_scale = ctx.spatial_scale
+ aligned = ctx.aligned
+ clockwise = ctx.clockwise
+ sample_num = ctx.sample_num
+ rois = ctx.saved_tensors[0]
+ assert feature_size is not None
+ batch_size, num_channels, data_height, data_width = feature_size
+
+ out_w = grad_output.size(3)
+ out_h = grad_output.size(2)
+
+ grad_input = grad_rois = None
+
+ if ctx.needs_input_grad[0]:
+ grad_input = rois.new_zeros(batch_size, num_channels, data_height,
+ data_width)
+ ext_module.roi_align_rotated_backward(
+ grad_output.contiguous(),
+ rois,
+ grad_input,
+ pooled_height=out_h,
+ pooled_width=out_w,
+ spatial_scale=spatial_scale,
+ sample_num=sample_num,
+ aligned=aligned,
+ clockwise=clockwise)
+ return grad_input, grad_rois, None, None, None, None, None
+
+
+roi_align_rotated = RoIAlignRotatedFunction.apply
+
+
+class RoIAlignRotated(nn.Module):
+ """RoI align pooling layer for rotated proposals.
+
+ It accepts a feature map of shape (N, C, H, W) and rois with shape
+ (n, 6) with each roi decoded as (batch_index, center_x, center_y,
+ w, h, angle). The angle is in radian.
+
+ Args:
+ out_size (tuple): h, w
+ spatial_scale (float): scale the input boxes by this number
+ sample_num (int): number of inputs samples to take for each
+ output sample. 0 to take samples densely for current models.
+ aligned (bool): if False, use the legacy implementation in
+ MMDetection. If True, align the results more perfectly.
+ Default: True.
+ clockwise (bool): If True, the angle in each proposal follows a
+ clockwise fashion in image space, otherwise, the angle is
+ counterclockwise. Default: False.
+
+ Note:
+ The implementation of RoIAlign when aligned=True is modified from
+ https://github.com/facebookresearch/detectron2/
+
+ The meaning of aligned=True:
+
+ Given a continuous coordinate c, its two neighboring pixel
+ indices (in our pixel model) are computed by floor(c - 0.5) and
+ ceil(c - 0.5). For example, c=1.3 has pixel neighbors with discrete
+ indices [0] and [1] (which are sampled from the underlying signal
+ at continuous coordinates 0.5 and 1.5). But the original roi_align
+ (aligned=False) does not subtract the 0.5 when computing
+ neighboring pixel indices and therefore it uses pixels with a
+ slightly incorrect alignment (relative to our pixel model) when
+ performing bilinear interpolation.
+
+ With `aligned=True`,
+ we first appropriately scale the ROI and then shift it by -0.5
+ prior to calling roi_align. This produces the correct neighbors;
+
+ The difference does not make a difference to the model's
+ performance if ROIAlign is used together with conv layers.
+ """
+
+ def __init__(self,
+ out_size,
+ spatial_scale,
+ sample_num=0,
+ aligned=True,
+ clockwise=False):
+ super(RoIAlignRotated, self).__init__()
+
+ self.out_size = out_size
+ self.spatial_scale = float(spatial_scale)
+ self.sample_num = int(sample_num)
+ self.aligned = aligned
+ self.clockwise = clockwise
+
+ def forward(self, features, rois):
+ return RoIAlignRotatedFunction.apply(features, rois, self.out_size,
+ self.spatial_scale,
+ self.sample_num, self.aligned,
+ self.clockwise)
diff --git a/src/custom_mmpkg/custom_mmcv/ops/roi_pool.py b/src/custom_mmpkg/custom_mmcv/ops/roi_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..d339d8f2941eabc1cbe181a9c6c5ab5ff4ff4e5f
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/roi_pool.py
@@ -0,0 +1,86 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext',
+ ['roi_pool_forward', 'roi_pool_backward'])
+
+
+class RoIPoolFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input, rois, output_size, spatial_scale):
+ return g.op(
+ 'MaxRoiPool',
+ input,
+ rois,
+ pooled_shape_i=output_size,
+ spatial_scale_f=spatial_scale)
+
+ @staticmethod
+ def forward(ctx, input, rois, output_size, spatial_scale=1.0):
+ ctx.output_size = _pair(output_size)
+ ctx.spatial_scale = spatial_scale
+ ctx.input_shape = input.size()
+
+ assert rois.size(1) == 5, 'RoI must be (idx, x1, y1, x2, y2)!'
+
+ output_shape = (rois.size(0), input.size(1), ctx.output_size[0],
+ ctx.output_size[1])
+ output = input.new_zeros(output_shape)
+ argmax = input.new_zeros(output_shape, dtype=torch.int)
+
+ ext_module.roi_pool_forward(
+ input,
+ rois,
+ output,
+ argmax,
+ pooled_height=ctx.output_size[0],
+ pooled_width=ctx.output_size[1],
+ spatial_scale=ctx.spatial_scale)
+
+ ctx.save_for_backward(rois, argmax)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ rois, argmax = ctx.saved_tensors
+ grad_input = grad_output.new_zeros(ctx.input_shape)
+
+ ext_module.roi_pool_backward(
+ grad_output,
+ rois,
+ argmax,
+ grad_input,
+ pooled_height=ctx.output_size[0],
+ pooled_width=ctx.output_size[1],
+ spatial_scale=ctx.spatial_scale)
+
+ return grad_input, None, None, None
+
+
+roi_pool = RoIPoolFunction.apply
+
+
+class RoIPool(nn.Module):
+
+ def __init__(self, output_size, spatial_scale=1.0):
+ super(RoIPool, self).__init__()
+
+ self.output_size = _pair(output_size)
+ self.spatial_scale = float(spatial_scale)
+
+ def forward(self, input, rois):
+ return roi_pool(input, rois, self.output_size, self.spatial_scale)
+
+ def __repr__(self):
+ s = self.__class__.__name__
+ s += f'(output_size={self.output_size}, '
+ s += f'spatial_scale={self.spatial_scale})'
+ return s
diff --git a/src/custom_mmpkg/custom_mmcv/ops/roiaware_pool3d.py b/src/custom_mmpkg/custom_mmcv/ops/roiaware_pool3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..00d8a4d7f99181f224bda079ff7487aae5b92383
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/roiaware_pool3d.py
@@ -0,0 +1,114 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn as nn
+from torch.autograd import Function
+
+import custom_mmpkg.custom_mmcv as mmcv
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['roiaware_pool3d_forward', 'roiaware_pool3d_backward'])
+
+
+class RoIAwarePool3d(nn.Module):
+ """Encode the geometry-specific features of each 3D proposal.
+
+ Please refer to `PartA2 `_ for more
+ details.
+
+ Args:
+ out_size (int or tuple): The size of output features. n or
+ [n1, n2, n3].
+ max_pts_per_voxel (int, optional): The maximum number of points per
+ voxel. Default: 128.
+ mode (str, optional): Pooling method of RoIAware, 'max' or 'avg'.
+ Default: 'max'.
+ """
+
+ def __init__(self, out_size, max_pts_per_voxel=128, mode='max'):
+ super().__init__()
+
+ self.out_size = out_size
+ self.max_pts_per_voxel = max_pts_per_voxel
+ assert mode in ['max', 'avg']
+ pool_mapping = {'max': 0, 'avg': 1}
+ self.mode = pool_mapping[mode]
+
+ def forward(self, rois, pts, pts_feature):
+ """
+ Args:
+ rois (torch.Tensor): [N, 7], in LiDAR coordinate,
+ (x, y, z) is the bottom center of rois.
+ pts (torch.Tensor): [npoints, 3], coordinates of input points.
+ pts_feature (torch.Tensor): [npoints, C], features of input points.
+
+ Returns:
+ pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C]
+ """
+
+ return RoIAwarePool3dFunction.apply(rois, pts, pts_feature,
+ self.out_size,
+ self.max_pts_per_voxel, self.mode)
+
+
+class RoIAwarePool3dFunction(Function):
+
+ @staticmethod
+ def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel,
+ mode):
+ """
+ Args:
+ rois (torch.Tensor): [N, 7], in LiDAR coordinate,
+ (x, y, z) is the bottom center of rois.
+ pts (torch.Tensor): [npoints, 3], coordinates of input points.
+ pts_feature (torch.Tensor): [npoints, C], features of input points.
+ out_size (int or tuple): The size of output features. n or
+ [n1, n2, n3].
+ max_pts_per_voxel (int): The maximum number of points per voxel.
+ Default: 128.
+ mode (int): Pooling method of RoIAware, 0 (max pool) or 1 (average
+ pool).
+
+ Returns:
+ pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C], output
+ pooled features.
+ """
+
+ if isinstance(out_size, int):
+ out_x = out_y = out_z = out_size
+ else:
+ assert len(out_size) == 3
+ assert mmcv.is_tuple_of(out_size, int)
+ out_x, out_y, out_z = out_size
+
+ num_rois = rois.shape[0]
+ num_channels = pts_feature.shape[-1]
+ num_pts = pts.shape[0]
+
+ pooled_features = pts_feature.new_zeros(
+ (num_rois, out_x, out_y, out_z, num_channels))
+ argmax = pts_feature.new_zeros(
+ (num_rois, out_x, out_y, out_z, num_channels), dtype=torch.int)
+ pts_idx_of_voxels = pts_feature.new_zeros(
+ (num_rois, out_x, out_y, out_z, max_pts_per_voxel),
+ dtype=torch.int)
+
+ ext_module.roiaware_pool3d_forward(rois, pts, pts_feature, argmax,
+ pts_idx_of_voxels, pooled_features,
+ mode)
+
+ ctx.roiaware_pool3d_for_backward = (pts_idx_of_voxels, argmax, mode,
+ num_pts, num_channels)
+ return pooled_features
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ ret = ctx.roiaware_pool3d_for_backward
+ pts_idx_of_voxels, argmax, mode, num_pts, num_channels = ret
+
+ grad_in = grad_out.new_zeros((num_pts, num_channels))
+ ext_module.roiaware_pool3d_backward(pts_idx_of_voxels, argmax,
+ grad_out.contiguous(), grad_in,
+ mode)
+
+ return None, None, grad_in, None, None, None
diff --git a/src/custom_mmpkg/custom_mmcv/ops/roipoint_pool3d.py b/src/custom_mmpkg/custom_mmcv/ops/roipoint_pool3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a21412c0728431c04b84245bc2e3109eea9aefc
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/roipoint_pool3d.py
@@ -0,0 +1,77 @@
+from torch import nn as nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['roipoint_pool3d_forward'])
+
+
+class RoIPointPool3d(nn.Module):
+ """Encode the geometry-specific features of each 3D proposal.
+
+ Please refer to `Paper of PartA2 `_
+ for more details.
+
+ Args:
+ num_sampled_points (int, optional): Number of samples in each roi.
+ Default: 512.
+ """
+
+ def __init__(self, num_sampled_points=512):
+ super().__init__()
+ self.num_sampled_points = num_sampled_points
+
+ def forward(self, points, point_features, boxes3d):
+ """
+ Args:
+ points (torch.Tensor): Input points whose shape is (B, N, C).
+ point_features (torch.Tensor): Features of input points whose shape
+ is (B, N, C).
+ boxes3d (B, M, 7), Input bounding boxes whose shape is (B, M, 7).
+
+ Returns:
+ pooled_features (torch.Tensor): The output pooled features whose
+ shape is (B, M, 512, 3 + C).
+ pooled_empty_flag (torch.Tensor): Empty flag whose shape is (B, M).
+ """
+ return RoIPointPool3dFunction.apply(points, point_features, boxes3d,
+ self.num_sampled_points)
+
+
+class RoIPointPool3dFunction(Function):
+
+ @staticmethod
+ def forward(ctx, points, point_features, boxes3d, num_sampled_points=512):
+ """
+ Args:
+ points (torch.Tensor): Input points whose shape is (B, N, C).
+ point_features (torch.Tensor): Features of input points whose shape
+ is (B, N, C).
+ boxes3d (B, M, 7), Input bounding boxes whose shape is (B, M, 7).
+ num_sampled_points (int, optional): The num of sampled points.
+ Default: 512.
+
+ Returns:
+ pooled_features (torch.Tensor): The output pooled features whose
+ shape is (B, M, 512, 3 + C).
+ pooled_empty_flag (torch.Tensor): Empty flag whose shape is (B, M).
+ """
+ assert len(points.shape) == 3 and points.shape[2] == 3
+ batch_size, boxes_num, feature_len = points.shape[0], boxes3d.shape[
+ 1], point_features.shape[2]
+ pooled_boxes3d = boxes3d.view(batch_size, -1, 7)
+ pooled_features = point_features.new_zeros(
+ (batch_size, boxes_num, num_sampled_points, 3 + feature_len))
+ pooled_empty_flag = point_features.new_zeros(
+ (batch_size, boxes_num)).int()
+
+ ext_module.roipoint_pool3d_forward(points.contiguous(),
+ pooled_boxes3d.contiguous(),
+ point_features.contiguous(),
+ pooled_features, pooled_empty_flag)
+
+ return pooled_features, pooled_empty_flag
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ raise NotImplementedError
diff --git a/src/custom_mmpkg/custom_mmcv/ops/saconv.py b/src/custom_mmpkg/custom_mmcv/ops/saconv.py
new file mode 100644
index 0000000000000000000000000000000000000000..46d26992534cba3ba0ee36f08b700c5489fea30d
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/saconv.py
@@ -0,0 +1,145 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from custom_mmpkg.custom_mmcv.cnn import CONV_LAYERS, ConvAWS2d, constant_init
+from custom_mmpkg.custom_mmcv.ops.deform_conv import deform_conv2d
+from custom_mmpkg.custom_mmcv.utils import TORCH_VERSION, digit_version
+
+
+@CONV_LAYERS.register_module(name='SAC')
+class SAConv2d(ConvAWS2d):
+ """SAC (Switchable Atrous Convolution)
+
+ This is an implementation of SAC in DetectoRS
+ (https://arxiv.org/pdf/2006.02334.pdf).
+
+ Args:
+ in_channels (int): Number of channels in the input image
+ out_channels (int): Number of channels produced by the convolution
+ kernel_size (int or tuple): Size of the convolving kernel
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
+ padding (int or tuple, optional): Zero-padding added to both sides of
+ the input. Default: 0
+ padding_mode (string, optional): ``'zeros'``, ``'reflect'``,
+ ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
+ dilation (int or tuple, optional): Spacing between kernel elements.
+ Default: 1
+ groups (int, optional): Number of blocked connections from input
+ channels to output channels. Default: 1
+ bias (bool, optional): If ``True``, adds a learnable bias to the
+ output. Default: ``True``
+ use_deform: If ``True``, replace convolution with deformable
+ convolution. Default: ``False``.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=True,
+ use_deform=False):
+ super().__init__(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias)
+ self.use_deform = use_deform
+ self.switch = nn.Conv2d(
+ self.in_channels, 1, kernel_size=1, stride=stride, bias=True)
+ self.weight_diff = nn.Parameter(torch.Tensor(self.weight.size()))
+ self.pre_context = nn.Conv2d(
+ self.in_channels, self.in_channels, kernel_size=1, bias=True)
+ self.post_context = nn.Conv2d(
+ self.out_channels, self.out_channels, kernel_size=1, bias=True)
+ if self.use_deform:
+ self.offset_s = nn.Conv2d(
+ self.in_channels,
+ 18,
+ kernel_size=3,
+ padding=1,
+ stride=stride,
+ bias=True)
+ self.offset_l = nn.Conv2d(
+ self.in_channels,
+ 18,
+ kernel_size=3,
+ padding=1,
+ stride=stride,
+ bias=True)
+ self.init_weights()
+
+ def init_weights(self):
+ constant_init(self.switch, 0, bias=1)
+ self.weight_diff.data.zero_()
+ constant_init(self.pre_context, 0)
+ constant_init(self.post_context, 0)
+ if self.use_deform:
+ constant_init(self.offset_s, 0)
+ constant_init(self.offset_l, 0)
+
+ def forward(self, x):
+ # pre-context
+ avg_x = F.adaptive_avg_pool2d(x, output_size=1)
+ avg_x = self.pre_context(avg_x)
+ avg_x = avg_x.expand_as(x)
+ x = x + avg_x
+ # switch
+ avg_x = F.pad(x, pad=(2, 2, 2, 2), mode='reflect')
+ avg_x = F.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)
+ switch = self.switch(avg_x)
+ # sac
+ weight = self._get_weight(self.weight)
+ zero_bias = torch.zeros(
+ self.out_channels, device=weight.device, dtype=weight.dtype)
+
+ if self.use_deform:
+ offset = self.offset_s(avg_x)
+ out_s = deform_conv2d(x, offset, weight, self.stride, self.padding,
+ self.dilation, self.groups, 1)
+ else:
+ if (TORCH_VERSION == 'parrots'
+ or digit_version(TORCH_VERSION) < digit_version('1.5.0')):
+ out_s = super().conv2d_forward(x, weight)
+ elif digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
+ # bias is a required argument of _conv_forward in torch 1.8.0
+ out_s = super()._conv_forward(x, weight, zero_bias)
+ else:
+ out_s = super()._conv_forward(x, weight)
+ ori_p = self.padding
+ ori_d = self.dilation
+ self.padding = tuple(3 * p for p in self.padding)
+ self.dilation = tuple(3 * d for d in self.dilation)
+ weight = weight + self.weight_diff
+ if self.use_deform:
+ offset = self.offset_l(avg_x)
+ out_l = deform_conv2d(x, offset, weight, self.stride, self.padding,
+ self.dilation, self.groups, 1)
+ else:
+ if (TORCH_VERSION == 'parrots'
+ or digit_version(TORCH_VERSION) < digit_version('1.5.0')):
+ out_l = super().conv2d_forward(x, weight)
+ elif digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
+ # bias is a required argument of _conv_forward in torch 1.8.0
+ out_l = super()._conv_forward(x, weight, zero_bias)
+ else:
+ out_l = super()._conv_forward(x, weight)
+
+ out = switch * out_s + (1 - switch) * out_l
+ self.padding = ori_p
+ self.dilation = ori_d
+ # post-context
+ avg_x = F.adaptive_avg_pool2d(out, output_size=1)
+ avg_x = self.post_context(avg_x)
+ avg_x = avg_x.expand_as(out)
+ out = out + avg_x
+ return out
diff --git a/src/custom_mmpkg/custom_mmcv/ops/scatter_points.py b/src/custom_mmpkg/custom_mmcv/ops/scatter_points.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b8aa4169e9f6ca4a6f845ce17d6d1e4db416bb8
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/scatter_points.py
@@ -0,0 +1,135 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext',
+ ['dynamic_point_to_voxel_forward', 'dynamic_point_to_voxel_backward'])
+
+
+class _DynamicScatter(Function):
+
+ @staticmethod
+ def forward(ctx, feats, coors, reduce_type='max'):
+ """convert kitti points(N, >=3) to voxels.
+
+ Args:
+ feats (torch.Tensor): [N, C]. Points features to be reduced
+ into voxels.
+ coors (torch.Tensor): [N, ndim]. Corresponding voxel coordinates
+ (specifically multi-dim voxel index) of each points.
+ reduce_type (str, optional): Reduce op. support 'max', 'sum' and
+ 'mean'. Default: 'max'.
+
+ Returns:
+ voxel_feats (torch.Tensor): [M, C]. Reduced features, input
+ features that shares the same voxel coordinates are reduced to
+ one row.
+ voxel_coors (torch.Tensor): [M, ndim]. Voxel coordinates.
+ """
+ results = ext_module.dynamic_point_to_voxel_forward(
+ feats, coors, reduce_type)
+ (voxel_feats, voxel_coors, point2voxel_map,
+ voxel_points_count) = results
+ ctx.reduce_type = reduce_type
+ ctx.save_for_backward(feats, voxel_feats, point2voxel_map,
+ voxel_points_count)
+ ctx.mark_non_differentiable(voxel_coors)
+ return voxel_feats, voxel_coors
+
+ @staticmethod
+ def backward(ctx, grad_voxel_feats, grad_voxel_coors=None):
+ (feats, voxel_feats, point2voxel_map,
+ voxel_points_count) = ctx.saved_tensors
+ grad_feats = torch.zeros_like(feats)
+ # TODO: whether to use index put or use cuda_backward
+ # To use index put, need point to voxel index
+ ext_module.dynamic_point_to_voxel_backward(
+ grad_feats, grad_voxel_feats.contiguous(), feats, voxel_feats,
+ point2voxel_map, voxel_points_count, ctx.reduce_type)
+ return grad_feats, None, None
+
+
+dynamic_scatter = _DynamicScatter.apply
+
+
+class DynamicScatter(nn.Module):
+ """Scatters points into voxels, used in the voxel encoder with dynamic
+ voxelization.
+
+ Note:
+ The CPU and GPU implementation get the same output, but have numerical
+ difference after summation and division (e.g., 5e-7).
+
+ Args:
+ voxel_size (list): list [x, y, z] size of three dimension.
+ point_cloud_range (list): The coordinate range of points, [x_min,
+ y_min, z_min, x_max, y_max, z_max].
+ average_points (bool): whether to use avg pooling to scatter points
+ into voxel.
+ """
+
+ def __init__(self, voxel_size, point_cloud_range, average_points: bool):
+ super().__init__()
+
+ self.voxel_size = voxel_size
+ self.point_cloud_range = point_cloud_range
+ self.average_points = average_points
+
+ def forward_single(self, points, coors):
+ """Scatters points into voxels.
+
+ Args:
+ points (torch.Tensor): Points to be reduced into voxels.
+ coors (torch.Tensor): Corresponding voxel coordinates (specifically
+ multi-dim voxel index) of each points.
+
+ Returns:
+ voxel_feats (torch.Tensor): Reduced features, input features that
+ shares the same voxel coordinates are reduced to one row.
+ voxel_coors (torch.Tensor): Voxel coordinates.
+ """
+ reduce = 'mean' if self.average_points else 'max'
+ return dynamic_scatter(points.contiguous(), coors.contiguous(), reduce)
+
+ def forward(self, points, coors):
+ """Scatters points/features into voxels.
+
+ Args:
+ points (torch.Tensor): Points to be reduced into voxels.
+ coors (torch.Tensor): Corresponding voxel coordinates (specifically
+ multi-dim voxel index) of each points.
+
+ Returns:
+ voxel_feats (torch.Tensor): Reduced features, input features that
+ shares the same voxel coordinates are reduced to one row.
+ voxel_coors (torch.Tensor): Voxel coordinates.
+ """
+ if coors.size(-1) == 3:
+ return self.forward_single(points, coors)
+ else:
+ batch_size = coors[-1, 0] + 1
+ voxels, voxel_coors = [], []
+ for i in range(batch_size):
+ inds = torch.where(coors[:, 0] == i)
+ voxel, voxel_coor = self.forward_single(
+ points[inds], coors[inds][:, 1:])
+ coor_pad = nn.functional.pad(
+ voxel_coor, (1, 0), mode='constant', value=i)
+ voxel_coors.append(coor_pad)
+ voxels.append(voxel)
+ features = torch.cat(voxels, dim=0)
+ feature_coors = torch.cat(voxel_coors, dim=0)
+
+ return features, feature_coors
+
+ def __repr__(self):
+ s = self.__class__.__name__ + '('
+ s += 'voxel_size=' + str(self.voxel_size)
+ s += ', point_cloud_range=' + str(self.point_cloud_range)
+ s += ', average_points=' + str(self.average_points)
+ s += ')'
+ return s
diff --git a/src/custom_mmpkg/custom_mmcv/ops/sync_bn.py b/src/custom_mmpkg/custom_mmcv/ops/sync_bn.py
new file mode 100644
index 0000000000000000000000000000000000000000..f885caac860ae7197ba2a29433b3c3debfdb2e65
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/sync_bn.py
@@ -0,0 +1,279 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.module import Module
+from torch.nn.parameter import Parameter
+
+from custom_mmpkg.custom_mmcv.cnn import NORM_LAYERS
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+ 'sync_bn_forward_mean', 'sync_bn_forward_var', 'sync_bn_forward_output',
+ 'sync_bn_backward_param', 'sync_bn_backward_data'
+])
+
+
+class SyncBatchNormFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input, running_mean, running_var, weight, bias, momentum,
+ eps, group, group_size, stats_mode):
+ return g.op(
+ 'mmcv::MMCVSyncBatchNorm',
+ input,
+ running_mean,
+ running_var,
+ weight,
+ bias,
+ momentum_f=momentum,
+ eps_f=eps,
+ group_i=group,
+ group_size_i=group_size,
+ stats_mode=stats_mode)
+
+ @staticmethod
+ def forward(self, input, running_mean, running_var, weight, bias, momentum,
+ eps, group, group_size, stats_mode):
+ self.momentum = momentum
+ self.eps = eps
+ self.group = group
+ self.group_size = group_size
+ self.stats_mode = stats_mode
+
+ assert isinstance(
+ input, (torch.HalfTensor, torch.FloatTensor,
+ torch.cuda.HalfTensor, torch.cuda.FloatTensor)), \
+ f'only support Half or Float Tensor, but {input.type()}'
+ output = torch.zeros_like(input)
+ input3d = input.flatten(start_dim=2)
+ output3d = output.view_as(input3d)
+ num_channels = input3d.size(1)
+
+ # ensure mean/var/norm/std are initialized as zeros
+ # ``torch.empty()`` does not guarantee that
+ mean = torch.zeros(
+ num_channels, dtype=torch.float, device=input3d.device)
+ var = torch.zeros(
+ num_channels, dtype=torch.float, device=input3d.device)
+ norm = torch.zeros_like(
+ input3d, dtype=torch.float, device=input3d.device)
+ std = torch.zeros(
+ num_channels, dtype=torch.float, device=input3d.device)
+
+ batch_size = input3d.size(0)
+ if batch_size > 0:
+ ext_module.sync_bn_forward_mean(input3d, mean)
+ batch_flag = torch.ones([1], device=mean.device, dtype=mean.dtype)
+ else:
+ # skip updating mean and leave it as zeros when the input is empty
+ batch_flag = torch.zeros([1], device=mean.device, dtype=mean.dtype)
+
+ # synchronize mean and the batch flag
+ vec = torch.cat([mean, batch_flag])
+ if self.stats_mode == 'N':
+ vec *= batch_size
+ if self.group_size > 1:
+ dist.all_reduce(vec, group=self.group)
+ total_batch = vec[-1].detach()
+ mean = vec[:num_channels]
+
+ if self.stats_mode == 'default':
+ mean = mean / self.group_size
+ elif self.stats_mode == 'N':
+ mean = mean / total_batch.clamp(min=1)
+ else:
+ raise NotImplementedError
+
+ # leave var as zeros when the input is empty
+ if batch_size > 0:
+ ext_module.sync_bn_forward_var(input3d, mean, var)
+
+ if self.stats_mode == 'N':
+ var *= batch_size
+ if self.group_size > 1:
+ dist.all_reduce(var, group=self.group)
+
+ if self.stats_mode == 'default':
+ var /= self.group_size
+ elif self.stats_mode == 'N':
+ var /= total_batch.clamp(min=1)
+ else:
+ raise NotImplementedError
+
+ # if the total batch size over all the ranks is zero,
+ # we should not update the statistics in the current batch
+ update_flag = total_batch.clamp(max=1)
+ momentum = update_flag * self.momentum
+ ext_module.sync_bn_forward_output(
+ input3d,
+ mean,
+ var,
+ weight,
+ bias,
+ running_mean,
+ running_var,
+ norm,
+ std,
+ output3d,
+ eps=self.eps,
+ momentum=momentum,
+ group_size=self.group_size)
+ self.save_for_backward(norm, std, weight)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(self, grad_output):
+ norm, std, weight = self.saved_tensors
+ grad_weight = torch.zeros_like(weight)
+ grad_bias = torch.zeros_like(weight)
+ grad_input = torch.zeros_like(grad_output)
+ grad_output3d = grad_output.flatten(start_dim=2)
+ grad_input3d = grad_input.view_as(grad_output3d)
+
+ batch_size = grad_input3d.size(0)
+ if batch_size > 0:
+ ext_module.sync_bn_backward_param(grad_output3d, norm, grad_weight,
+ grad_bias)
+
+ # all reduce
+ if self.group_size > 1:
+ dist.all_reduce(grad_weight, group=self.group)
+ dist.all_reduce(grad_bias, group=self.group)
+ grad_weight /= self.group_size
+ grad_bias /= self.group_size
+
+ if batch_size > 0:
+ ext_module.sync_bn_backward_data(grad_output3d, weight,
+ grad_weight, grad_bias, norm, std,
+ grad_input3d)
+
+ return grad_input, None, None, grad_weight, grad_bias, \
+ None, None, None, None, None
+
+
+@NORM_LAYERS.register_module(name='MMSyncBN')
+class SyncBatchNorm(Module):
+ """Synchronized Batch Normalization.
+
+ Args:
+ num_features (int): number of features/chennels in input tensor
+ eps (float, optional): a value added to the denominator for numerical
+ stability. Defaults to 1e-5.
+ momentum (float, optional): the value used for the running_mean and
+ running_var computation. Defaults to 0.1.
+ affine (bool, optional): whether to use learnable affine parameters.
+ Defaults to True.
+ track_running_stats (bool, optional): whether to track the running
+ mean and variance during training. When set to False, this
+ module does not track such statistics, and initializes statistics
+ buffers ``running_mean`` and ``running_var`` as ``None``. When
+ these buffers are ``None``, this module always uses batch
+ statistics in both training and eval modes. Defaults to True.
+ group (int, optional): synchronization of stats happen within
+ each process group individually. By default it is synchronization
+ across the whole world. Defaults to None.
+ stats_mode (str, optional): The statistical mode. Available options
+ includes ``'default'`` and ``'N'``. Defaults to 'default'.
+ When ``stats_mode=='default'``, it computes the overall statistics
+ using those from each worker with equal weight, i.e., the
+ statistics are synchronized and simply divied by ``group``. This
+ mode will produce inaccurate statistics when empty tensors occur.
+ When ``stats_mode=='N'``, it compute the overall statistics using
+ the total number of batches in each worker ignoring the number of
+ group, i.e., the statistics are synchronized and then divied by
+ the total batch ``N``. This mode is beneficial when empty tensors
+ occur during training, as it average the total mean by the real
+ number of batch.
+ """
+
+ def __init__(self,
+ num_features,
+ eps=1e-5,
+ momentum=0.1,
+ affine=True,
+ track_running_stats=True,
+ group=None,
+ stats_mode='default'):
+ super(SyncBatchNorm, self).__init__()
+ self.num_features = num_features
+ self.eps = eps
+ self.momentum = momentum
+ self.affine = affine
+ self.track_running_stats = track_running_stats
+ group = dist.group.WORLD if group is None else group
+ self.group = group
+ self.group_size = dist.get_world_size(group)
+ assert stats_mode in ['default', 'N'], \
+ f'"stats_mode" only accepts "default" and "N", got "{stats_mode}"'
+ self.stats_mode = stats_mode
+ if self.affine:
+ self.weight = Parameter(torch.Tensor(num_features))
+ self.bias = Parameter(torch.Tensor(num_features))
+ else:
+ self.register_parameter('weight', None)
+ self.register_parameter('bias', None)
+ if self.track_running_stats:
+ self.register_buffer('running_mean', torch.zeros(num_features))
+ self.register_buffer('running_var', torch.ones(num_features))
+ self.register_buffer('num_batches_tracked',
+ torch.tensor(0, dtype=torch.long))
+ else:
+ self.register_buffer('running_mean', None)
+ self.register_buffer('running_var', None)
+ self.register_buffer('num_batches_tracked', None)
+ self.reset_parameters()
+
+ def reset_running_stats(self):
+ if self.track_running_stats:
+ self.running_mean.zero_()
+ self.running_var.fill_(1)
+ self.num_batches_tracked.zero_()
+
+ def reset_parameters(self):
+ self.reset_running_stats()
+ if self.affine:
+ self.weight.data.uniform_() # pytorch use ones_()
+ self.bias.data.zero_()
+
+ def forward(self, input):
+ if input.dim() < 2:
+ raise ValueError(
+ f'expected at least 2D input, got {input.dim()}D input')
+ if self.momentum is None:
+ exponential_average_factor = 0.0
+ else:
+ exponential_average_factor = self.momentum
+
+ if self.training and self.track_running_stats:
+ if self.num_batches_tracked is not None:
+ self.num_batches_tracked += 1
+ if self.momentum is None: # use cumulative moving average
+ exponential_average_factor = 1.0 / float(
+ self.num_batches_tracked)
+ else: # use exponential moving average
+ exponential_average_factor = self.momentum
+
+ if self.training or not self.track_running_stats:
+ return SyncBatchNormFunction.apply(
+ input, self.running_mean, self.running_var, self.weight,
+ self.bias, exponential_average_factor, self.eps, self.group,
+ self.group_size, self.stats_mode)
+ else:
+ return F.batch_norm(input, self.running_mean, self.running_var,
+ self.weight, self.bias, False,
+ exponential_average_factor, self.eps)
+
+ def __repr__(self):
+ s = self.__class__.__name__
+ s += f'({self.num_features}, '
+ s += f'eps={self.eps}, '
+ s += f'momentum={self.momentum}, '
+ s += f'affine={self.affine}, '
+ s += f'track_running_stats={self.track_running_stats}, '
+ s += f'group_size={self.group_size},'
+ s += f'stats_mode={self.stats_mode})'
+ return s
diff --git a/src/custom_mmpkg/custom_mmcv/ops/three_interpolate.py b/src/custom_mmpkg/custom_mmcv/ops/three_interpolate.py
new file mode 100644
index 0000000000000000000000000000000000000000..203f47f05d58087e034fb3cd8cd6a09233947b4a
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/three_interpolate.py
@@ -0,0 +1,68 @@
+from typing import Tuple
+
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['three_interpolate_forward', 'three_interpolate_backward'])
+
+
+class ThreeInterpolate(Function):
+ """Performs weighted linear interpolation on 3 features.
+
+ Please refer to `Paper of PointNet++ `_
+ for more details.
+ """
+
+ @staticmethod
+ def forward(ctx, features: torch.Tensor, indices: torch.Tensor,
+ weight: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ features (Tensor): (B, C, M) Features descriptors to be
+ interpolated
+ indices (Tensor): (B, n, 3) index three nearest neighbors
+ of the target features in features
+ weight (Tensor): (B, n, 3) weights of interpolation
+
+ Returns:
+ Tensor: (B, C, N) tensor of the interpolated features
+ """
+ assert features.is_contiguous()
+ assert indices.is_contiguous()
+ assert weight.is_contiguous()
+
+ B, c, m = features.size()
+ n = indices.size(1)
+ ctx.three_interpolate_for_backward = (indices, weight, m)
+ output = torch.cuda.FloatTensor(B, c, n)
+
+ ext_module.three_interpolate_forward(
+ features, indices, weight, output, b=B, c=c, m=m, n=n)
+ return output
+
+ @staticmethod
+ def backward(
+ ctx, grad_out: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ grad_out (Tensor): (B, C, N) tensor with gradients of outputs
+
+ Returns:
+ Tensor: (B, C, M) tensor with gradients of features
+ """
+ idx, weight, m = ctx.three_interpolate_for_backward
+ B, c, n = grad_out.size()
+
+ grad_features = torch.cuda.FloatTensor(B, c, m).zero_()
+ grad_out_data = grad_out.data.contiguous()
+
+ ext_module.three_interpolate_backward(
+ grad_out_data, idx, weight, grad_features.data, b=B, c=c, n=n, m=m)
+ return grad_features, None, None
+
+
+three_interpolate = ThreeInterpolate.apply
diff --git a/src/custom_mmpkg/custom_mmcv/ops/three_nn.py b/src/custom_mmpkg/custom_mmcv/ops/three_nn.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b01047a129989cd5545a0a86f23a487f4a13ce1
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/three_nn.py
@@ -0,0 +1,51 @@
+from typing import Tuple
+
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['three_nn_forward'])
+
+
+class ThreeNN(Function):
+ """Find the top-3 nearest neighbors of the target set from the source set.
+
+ Please refer to `Paper of PointNet++ `_
+ for more details.
+ """
+
+ @staticmethod
+ def forward(ctx, target: torch.Tensor,
+ source: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ target (Tensor): shape (B, N, 3), points set that needs to
+ find the nearest neighbors.
+ source (Tensor): shape (B, M, 3), points set that is used
+ to find the nearest neighbors of points in target set.
+
+ Returns:
+ Tensor: shape (B, N, 3), L2 distance of each point in target
+ set to their corresponding nearest neighbors.
+ """
+ target = target.contiguous()
+ source = source.contiguous()
+
+ B, N, _ = target.size()
+ m = source.size(1)
+ dist2 = torch.cuda.FloatTensor(B, N, 3)
+ idx = torch.cuda.IntTensor(B, N, 3)
+
+ ext_module.three_nn_forward(target, source, dist2, idx, b=B, n=N, m=m)
+ if torch.__version__ != 'parrots':
+ ctx.mark_non_differentiable(idx)
+
+ return torch.sqrt(dist2), idx
+
+ @staticmethod
+ def backward(ctx, a=None, b=None):
+ return None, None
+
+
+three_nn = ThreeNN.apply
diff --git a/src/custom_mmpkg/custom_mmcv/ops/tin_shift.py b/src/custom_mmpkg/custom_mmcv/ops/tin_shift.py
new file mode 100644
index 0000000000000000000000000000000000000000..472c9fcfe45a124e819b7ed5653e585f94a8811e
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/tin_shift.py
@@ -0,0 +1,68 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Code reference from "Temporal Interlacing Network"
+# https://github.com/deepcs233/TIN/blob/master/cuda_shift/rtc_wrap.py
+# Hao Shao, Shengju Qian, Yu Liu
+# shaoh19@mails.tsinghua.edu.cn, sjqian@cse.cuhk.edu.hk, yuliu@ee.cuhk.edu.hk
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext',
+ ['tin_shift_forward', 'tin_shift_backward'])
+
+
+class TINShiftFunction(Function):
+
+ @staticmethod
+ def forward(ctx, input, shift):
+ C = input.size(2)
+ num_segments = shift.size(1)
+ if C // num_segments <= 0 or C % num_segments != 0:
+ raise ValueError('C should be a multiple of num_segments, '
+ f'but got C={C} and num_segments={num_segments}.')
+
+ ctx.save_for_backward(shift)
+
+ out = torch.zeros_like(input)
+ ext_module.tin_shift_forward(input, shift, out)
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+
+ shift = ctx.saved_tensors[0]
+ data_grad_input = grad_output.new(*grad_output.size()).zero_()
+ shift_grad_input = shift.new(*shift.size()).zero_()
+ ext_module.tin_shift_backward(grad_output, shift, data_grad_input)
+
+ return data_grad_input, shift_grad_input
+
+
+tin_shift = TINShiftFunction.apply
+
+
+class TINShift(nn.Module):
+ """Temporal Interlace Shift.
+
+ Temporal Interlace shift is a differentiable temporal-wise frame shifting
+ which is proposed in "Temporal Interlacing Network"
+
+ Please refer to https://arxiv.org/abs/2001.06499 for more details.
+ Code is modified from https://github.com/mit-han-lab/temporal-shift-module
+ """
+
+ def forward(self, input, shift):
+ """Perform temporal interlace shift.
+
+ Args:
+ input (Tensor): Feature map with shape [N, num_segments, C, H * W].
+ shift (Tensor): Shift tensor with shape [N, num_segments].
+
+ Returns:
+ Feature map after temporal interlace shift.
+ """
+ return tin_shift(input, shift)
diff --git a/src/custom_mmpkg/custom_mmcv/ops/upfirdn2d.py b/src/custom_mmpkg/custom_mmcv/ops/upfirdn2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef4a5236dda57340017f0e16857bca297d4e1b2f
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/upfirdn2d.py
@@ -0,0 +1,330 @@
+# modified from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501
+
+# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
+# NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator
+# Augmentation (ADA)
+# =======================================================================
+
+# 1. Definitions
+
+# "Licensor" means any person or entity that distributes its Work.
+
+# "Software" means the original work of authorship made available under
+# this License.
+
+# "Work" means the Software and any additions to or derivative works of
+# the Software that are made available under this License.
+
+# The terms "reproduce," "reproduction," "derivative works," and
+# "distribution" have the meaning as provided under U.S. copyright law;
+# provided, however, that for the purposes of this License, derivative
+# works shall not include works that remain separable from, or merely
+# link (or bind by name) to the interfaces of, the Work.
+
+# Works, including the Software, are "made available" under this License
+# by including in or with the Work either (a) a copyright notice
+# referencing the applicability of this License to the Work, or (b) a
+# copy of this License.
+
+# 2. License Grants
+
+# 2.1 Copyright Grant. Subject to the terms and conditions of this
+# License, each Licensor grants to you a perpetual, worldwide,
+# non-exclusive, royalty-free, copyright license to reproduce,
+# prepare derivative works of, publicly display, publicly perform,
+# sublicense and distribute its Work and any resulting derivative
+# works in any form.
+
+# 3. Limitations
+
+# 3.1 Redistribution. You may reproduce or distribute the Work only
+# if (a) you do so under this License, (b) you include a complete
+# copy of this License with your distribution, and (c) you retain
+# without modification any copyright, patent, trademark, or
+# attribution notices that are present in the Work.
+
+# 3.2 Derivative Works. You may specify that additional or different
+# terms apply to the use, reproduction, and distribution of your
+# derivative works of the Work ("Your Terms") only if (a) Your Terms
+# provide that the use limitation in Section 3.3 applies to your
+# derivative works, and (b) you identify the specific derivative
+# works that are subject to Your Terms. Notwithstanding Your Terms,
+# this License (including the redistribution requirements in Section
+# 3.1) will continue to apply to the Work itself.
+
+# 3.3 Use Limitation. The Work and any derivative works thereof only
+# may be used or intended for use non-commercially. Notwithstanding
+# the foregoing, NVIDIA and its affiliates may use the Work and any
+# derivative works commercially. As used herein, "non-commercially"
+# means for research or evaluation purposes only.
+
+# 3.4 Patent Claims. If you bring or threaten to bring a patent claim
+# against any Licensor (including any claim, cross-claim or
+# counterclaim in a lawsuit) to enforce any patents that you allege
+# are infringed by any Work, then your rights under this License from
+# such Licensor (including the grant in Section 2.1) will terminate
+# immediately.
+
+# 3.5 Trademarks. This License does not grant any rights to use any
+# Licensor’s or its affiliates’ names, logos, or trademarks, except
+# as necessary to reproduce the notices described in this License.
+
+# 3.6 Termination. If you violate any term of this License, then your
+# rights under this License (including the grant in Section 2.1) will
+# terminate immediately.
+
+# 4. Disclaimer of Warranty.
+
+# THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
+# NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
+# THIS LICENSE.
+
+# 5. Limitation of Liability.
+
+# EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
+# THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
+# SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
+# INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
+# OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
+# (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
+# LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
+# COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGES.
+
+# =======================================================================
+
+import torch
+from torch.autograd import Function
+from torch.nn import functional as F
+
+from custom_mmpkg.custom_mmcv.utils import to_2tuple
+from ..utils import ext_loader
+
+upfirdn2d_ext = ext_loader.load_ext('_ext', ['upfirdn2d'])
+
+
+class UpFirDn2dBackward(Function):
+
+ @staticmethod
+ def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad,
+ in_size, out_size):
+
+ up_x, up_y = up
+ down_x, down_y = down
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
+
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
+
+ grad_input = upfirdn2d_ext.upfirdn2d(
+ grad_output,
+ grad_kernel,
+ up_x=down_x,
+ up_y=down_y,
+ down_x=up_x,
+ down_y=up_y,
+ pad_x0=g_pad_x0,
+ pad_x1=g_pad_x1,
+ pad_y0=g_pad_y0,
+ pad_y1=g_pad_y1)
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2],
+ in_size[3])
+
+ ctx.save_for_backward(kernel)
+
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+ ctx.up_x = up_x
+ ctx.up_y = up_y
+ ctx.down_x = down_x
+ ctx.down_y = down_y
+ ctx.pad_x0 = pad_x0
+ ctx.pad_x1 = pad_x1
+ ctx.pad_y0 = pad_y0
+ ctx.pad_y1 = pad_y1
+ ctx.in_size = in_size
+ ctx.out_size = out_size
+
+ return grad_input
+
+ @staticmethod
+ def backward(ctx, gradgrad_input):
+ kernel, = ctx.saved_tensors
+
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2],
+ ctx.in_size[3], 1)
+
+ gradgrad_out = upfirdn2d_ext.upfirdn2d(
+ gradgrad_input,
+ kernel,
+ up_x=ctx.up_x,
+ up_y=ctx.up_y,
+ down_x=ctx.down_x,
+ down_y=ctx.down_y,
+ pad_x0=ctx.pad_x0,
+ pad_x1=ctx.pad_x1,
+ pad_y0=ctx.pad_y0,
+ pad_y1=ctx.pad_y1)
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
+ # ctx.out_size[1], ctx.in_size[3])
+ gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1],
+ ctx.out_size[0], ctx.out_size[1])
+
+ return gradgrad_out, None, None, None, None, None, None, None, None
+
+
+class UpFirDn2d(Function):
+
+ @staticmethod
+ def forward(ctx, input, kernel, up, down, pad):
+ up_x, up_y = up
+ down_x, down_y = down
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+ kernel_h, kernel_w = kernel.shape
+ batch, channel, in_h, in_w = input.shape
+ ctx.in_size = input.shape
+
+ input = input.reshape(-1, in_h, in_w, 1)
+
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+ ctx.out_size = (out_h, out_w)
+
+ ctx.up = (up_x, up_y)
+ ctx.down = (down_x, down_y)
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
+
+ g_pad_x0 = kernel_w - pad_x0 - 1
+ g_pad_y0 = kernel_h - pad_y0 - 1
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
+
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
+
+ out = upfirdn2d_ext.upfirdn2d(
+ input,
+ kernel,
+ up_x=up_x,
+ up_y=up_y,
+ down_x=down_x,
+ down_y=down_y,
+ pad_x0=pad_x0,
+ pad_x1=pad_x1,
+ pad_y0=pad_y0,
+ pad_y1=pad_y1)
+ # out = out.view(major, out_h, out_w, minor)
+ out = out.view(-1, channel, out_h, out_w)
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ kernel, grad_kernel = ctx.saved_tensors
+
+ grad_input = UpFirDn2dBackward.apply(
+ grad_output,
+ kernel,
+ grad_kernel,
+ ctx.up,
+ ctx.down,
+ ctx.pad,
+ ctx.g_pad,
+ ctx.in_size,
+ ctx.out_size,
+ )
+
+ return grad_input, None, None, None, None
+
+
+def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+ """UpFRIDn for 2d features.
+
+ UpFIRDn is short for upsample, apply FIR filter and downsample. More
+ details can be found in:
+ https://www.mathworks.com/help/signal/ref/upfirdn.html
+
+ Args:
+ input (Tensor): Tensor with shape of (n, c, h, w).
+ kernel (Tensor): Filter kernel.
+ up (int | tuple[int], optional): Upsampling factor. If given a number,
+ we will use this factor for the both height and width side.
+ Defaults to 1.
+ down (int | tuple[int], optional): Downsampling factor. If given a
+ number, we will use this factor for the both height and width side.
+ Defaults to 1.
+ pad (tuple[int], optional): Padding for tensors, (x_pad, y_pad) or
+ (x_pad_0, x_pad_1, y_pad_0, y_pad_1). Defaults to (0, 0).
+
+ Returns:
+ Tensor: Tensor after UpFIRDn.
+ """
+ if input.device.type == 'cpu':
+ if len(pad) == 2:
+ pad = (pad[0], pad[1], pad[0], pad[1])
+
+ up = to_2tuple(up)
+
+ down = to_2tuple(down)
+
+ out = upfirdn2d_native(input, kernel, up[0], up[1], down[0], down[1],
+ pad[0], pad[1], pad[2], pad[3])
+ else:
+ _up = to_2tuple(up)
+
+ _down = to_2tuple(down)
+
+ if len(pad) == 4:
+ _pad = pad
+ elif len(pad) == 2:
+ _pad = (pad[0], pad[1], pad[0], pad[1])
+
+ out = UpFirDn2d.apply(input, kernel, _up, _down, _pad)
+
+ return out
+
+
+def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
+ pad_y0, pad_y1):
+ _, channel, in_h, in_w = input.shape
+ input = input.reshape(-1, in_h, in_w, 1)
+
+ _, in_h, in_w, minor = input.shape
+ kernel_h, kernel_w = kernel.shape
+
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+ out = F.pad(
+ out,
+ [0, 0,
+ max(pad_x0, 0),
+ max(pad_x1, 0),
+ max(pad_y0, 0),
+ max(pad_y1, 0)])
+ out = out[:,
+ max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0),
+ max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ]
+
+ out = out.permute(0, 3, 1, 2)
+ out = out.reshape(
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+ out = F.conv2d(out, w)
+ out = out.reshape(
+ -1,
+ minor,
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+ )
+ out = out.permute(0, 2, 3, 1)
+ out = out[:, ::down_y, ::down_x, :]
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+
+ return out.view(-1, channel, out_h, out_w)
diff --git a/src/custom_mmpkg/custom_mmcv/ops/voxelize.py b/src/custom_mmpkg/custom_mmcv/ops/voxelize.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca3226a4fbcbfe58490fa2ea8e1c16b531214121
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/ops/voxelize.py
@@ -0,0 +1,132 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn
+from torch.autograd import Function
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['dynamic_voxelize_forward', 'hard_voxelize_forward'])
+
+
+class _Voxelization(Function):
+
+ @staticmethod
+ def forward(ctx,
+ points,
+ voxel_size,
+ coors_range,
+ max_points=35,
+ max_voxels=20000):
+ """Convert kitti points(N, >=3) to voxels.
+
+ Args:
+ points (torch.Tensor): [N, ndim]. Points[:, :3] contain xyz points
+ and points[:, 3:] contain other information like reflectivity.
+ voxel_size (tuple or float): The size of voxel with the shape of
+ [3].
+ coors_range (tuple or float): The coordinate range of voxel with
+ the shape of [6].
+ max_points (int, optional): maximum points contained in a voxel. if
+ max_points=-1, it means using dynamic_voxelize. Default: 35.
+ max_voxels (int, optional): maximum voxels this function create.
+ for second, 20000 is a good choice. Users should shuffle points
+ before call this function because max_voxels may drop points.
+ Default: 20000.
+
+ Returns:
+ voxels_out (torch.Tensor): Output voxels with the shape of [M,
+ max_points, ndim]. Only contain points and returned when
+ max_points != -1.
+ coors_out (torch.Tensor): Output coordinates with the shape of
+ [M, 3].
+ num_points_per_voxel_out (torch.Tensor): Num points per voxel with
+ the shape of [M]. Only returned when max_points != -1.
+ """
+ if max_points == -1 or max_voxels == -1:
+ coors = points.new_zeros(size=(points.size(0), 3), dtype=torch.int)
+ ext_module.dynamic_voxelize_forward(points, coors, voxel_size,
+ coors_range, 3)
+ return coors
+ else:
+ voxels = points.new_zeros(
+ size=(max_voxels, max_points, points.size(1)))
+ coors = points.new_zeros(size=(max_voxels, 3), dtype=torch.int)
+ num_points_per_voxel = points.new_zeros(
+ size=(max_voxels, ), dtype=torch.int)
+ voxel_num = ext_module.hard_voxelize_forward(
+ points, voxels, coors, num_points_per_voxel, voxel_size,
+ coors_range, max_points, max_voxels, 3)
+ # select the valid voxels
+ voxels_out = voxels[:voxel_num]
+ coors_out = coors[:voxel_num]
+ num_points_per_voxel_out = num_points_per_voxel[:voxel_num]
+ return voxels_out, coors_out, num_points_per_voxel_out
+
+
+voxelization = _Voxelization.apply
+
+
+class Voxelization(nn.Module):
+ """Convert kitti points(N, >=3) to voxels.
+
+ Please refer to `PVCNN `_ for more
+ details.
+
+ Args:
+ voxel_size (tuple or float): The size of voxel with the shape of [3].
+ point_cloud_range (tuple or float): The coordinate range of voxel with
+ the shape of [6].
+ max_num_points (int): maximum points contained in a voxel. if
+ max_points=-1, it means using dynamic_voxelize.
+ max_voxels (int, optional): maximum voxels this function create.
+ for second, 20000 is a good choice. Users should shuffle points
+ before call this function because max_voxels may drop points.
+ Default: 20000.
+ """
+
+ def __init__(self,
+ voxel_size,
+ point_cloud_range,
+ max_num_points,
+ max_voxels=20000):
+ super().__init__()
+
+ self.voxel_size = voxel_size
+ self.point_cloud_range = point_cloud_range
+ self.max_num_points = max_num_points
+ if isinstance(max_voxels, tuple):
+ self.max_voxels = max_voxels
+ else:
+ self.max_voxels = _pair(max_voxels)
+
+ point_cloud_range = torch.tensor(
+ point_cloud_range, dtype=torch.float32)
+ voxel_size = torch.tensor(voxel_size, dtype=torch.float32)
+ grid_size = (point_cloud_range[3:] -
+ point_cloud_range[:3]) / voxel_size
+ grid_size = torch.round(grid_size).long()
+ input_feat_shape = grid_size[:2]
+ self.grid_size = grid_size
+ # the origin shape is as [x-len, y-len, z-len]
+ # [w, h, d] -> [d, h, w]
+ self.pcd_shape = [*input_feat_shape, 1][::-1]
+
+ def forward(self, input):
+ if self.training:
+ max_voxels = self.max_voxels[0]
+ else:
+ max_voxels = self.max_voxels[1]
+
+ return voxelization(input, self.voxel_size, self.point_cloud_range,
+ self.max_num_points, max_voxels)
+
+ def __repr__(self):
+ s = self.__class__.__name__ + '('
+ s += 'voxel_size=' + str(self.voxel_size)
+ s += ', point_cloud_range=' + str(self.point_cloud_range)
+ s += ', max_num_points=' + str(self.max_num_points)
+ s += ', max_voxels=' + str(self.max_voxels)
+ s += ')'
+ return s
diff --git a/src/custom_mmpkg/custom_mmcv/parallel/__init__.py b/src/custom_mmpkg/custom_mmcv/parallel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ed2c17ad357742e423beeaf4d35db03fe9af469
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/parallel/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .collate import collate
+from .data_container import DataContainer
+from .data_parallel import MMDataParallel
+from .distributed import MMDistributedDataParallel
+from .registry import MODULE_WRAPPERS
+from .scatter_gather import scatter, scatter_kwargs
+from .utils import is_module_wrapper
+
+__all__ = [
+ 'collate', 'DataContainer', 'MMDataParallel', 'MMDistributedDataParallel',
+ 'scatter', 'scatter_kwargs', 'is_module_wrapper', 'MODULE_WRAPPERS'
+]
diff --git a/src/custom_mmpkg/custom_mmcv/parallel/_functions.py b/src/custom_mmpkg/custom_mmcv/parallel/_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b5a8a44483ab991411d07122b22a1d027e4be8e
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/parallel/_functions.py
@@ -0,0 +1,79 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.nn.parallel._functions import _get_stream
+
+
+def scatter(input, devices, streams=None):
+ """Scatters tensor across multiple GPUs."""
+ if streams is None:
+ streams = [None] * len(devices)
+
+ if isinstance(input, list):
+ chunk_size = (len(input) - 1) // len(devices) + 1
+ outputs = [
+ scatter(input[i], [devices[i // chunk_size]],
+ [streams[i // chunk_size]]) for i in range(len(input))
+ ]
+ return outputs
+ elif isinstance(input, torch.Tensor):
+ output = input.contiguous()
+ # TODO: copy to a pinned buffer first (if copying from CPU)
+ stream = streams[0] if output.numel() > 0 else None
+ if devices != [-1]:
+ with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
+ output = output.cuda(devices[0], non_blocking=True)
+ else:
+ # unsqueeze the first dimension thus the tensor's shape is the
+ # same as those scattered with GPU.
+ output = output.unsqueeze(0)
+ return output
+ else:
+ raise Exception(f'Unknown type {type(input)}.')
+
+
+def synchronize_stream(output, devices, streams):
+ if isinstance(output, list):
+ chunk_size = len(output) // len(devices)
+ for i in range(len(devices)):
+ for j in range(chunk_size):
+ synchronize_stream(output[i * chunk_size + j], [devices[i]],
+ [streams[i]])
+ elif isinstance(output, torch.Tensor):
+ if output.numel() != 0:
+ with torch.cuda.device(devices[0]):
+ main_stream = torch.cuda.current_stream()
+ main_stream.wait_stream(streams[0])
+ output.record_stream(main_stream)
+ else:
+ raise Exception(f'Unknown type {type(output)}.')
+
+
+def get_input_device(input):
+ if isinstance(input, list):
+ for item in input:
+ input_device = get_input_device(item)
+ if input_device != -1:
+ return input_device
+ return -1
+ elif isinstance(input, torch.Tensor):
+ return input.get_device() if input.is_cuda else -1
+ else:
+ raise Exception(f'Unknown type {type(input)}.')
+
+
+class Scatter:
+
+ @staticmethod
+ def forward(target_gpus, input):
+ input_device = get_input_device(input)
+ streams = None
+ if input_device == -1 and target_gpus != [-1]:
+ # Perform CPU to GPU copies in a background stream
+ streams = [_get_stream(device) for device in target_gpus]
+
+ outputs = scatter(input, target_gpus, streams)
+ # Synchronize with the copy stream
+ if streams is not None:
+ synchronize_stream(outputs, target_gpus, streams)
+
+ return tuple(outputs)
diff --git a/src/custom_mmpkg/custom_mmcv/parallel/collate.py b/src/custom_mmpkg/custom_mmcv/parallel/collate.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad749197df21b0d74297548be5f66a696adebf7f
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/parallel/collate.py
@@ -0,0 +1,84 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections.abc import Mapping, Sequence
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data.dataloader import default_collate
+
+from .data_container import DataContainer
+
+
+def collate(batch, samples_per_gpu=1):
+ """Puts each data field into a tensor/DataContainer with outer dimension
+ batch size.
+
+ Extend default_collate to add support for
+ :type:`~mmcv.parallel.DataContainer`. There are 3 cases.
+
+ 1. cpu_only = True, e.g., meta data
+ 2. cpu_only = False, stack = True, e.g., images tensors
+ 3. cpu_only = False, stack = False, e.g., gt bboxes
+ """
+
+ if not isinstance(batch, Sequence):
+ raise TypeError(f'{batch.dtype} is not supported.')
+
+ if isinstance(batch[0], DataContainer):
+ stacked = []
+ if batch[0].cpu_only:
+ for i in range(0, len(batch), samples_per_gpu):
+ stacked.append(
+ [sample.data for sample in batch[i:i + samples_per_gpu]])
+ return DataContainer(
+ stacked, batch[0].stack, batch[0].padding_value, cpu_only=True)
+ elif batch[0].stack:
+ for i in range(0, len(batch), samples_per_gpu):
+ assert isinstance(batch[i].data, torch.Tensor)
+
+ if batch[i].pad_dims is not None:
+ ndim = batch[i].dim()
+ assert ndim > batch[i].pad_dims
+ max_shape = [0 for _ in range(batch[i].pad_dims)]
+ for dim in range(1, batch[i].pad_dims + 1):
+ max_shape[dim - 1] = batch[i].size(-dim)
+ for sample in batch[i:i + samples_per_gpu]:
+ for dim in range(0, ndim - batch[i].pad_dims):
+ assert batch[i].size(dim) == sample.size(dim)
+ for dim in range(1, batch[i].pad_dims + 1):
+ max_shape[dim - 1] = max(max_shape[dim - 1],
+ sample.size(-dim))
+ padded_samples = []
+ for sample in batch[i:i + samples_per_gpu]:
+ pad = [0 for _ in range(batch[i].pad_dims * 2)]
+ for dim in range(1, batch[i].pad_dims + 1):
+ pad[2 * dim -
+ 1] = max_shape[dim - 1] - sample.size(-dim)
+ padded_samples.append(
+ F.pad(
+ sample.data, pad, value=sample.padding_value))
+ stacked.append(default_collate(padded_samples))
+ elif batch[i].pad_dims is None:
+ stacked.append(
+ default_collate([
+ sample.data
+ for sample in batch[i:i + samples_per_gpu]
+ ]))
+ else:
+ raise ValueError(
+ 'pad_dims should be either None or integers (1-3)')
+
+ else:
+ for i in range(0, len(batch), samples_per_gpu):
+ stacked.append(
+ [sample.data for sample in batch[i:i + samples_per_gpu]])
+ return DataContainer(stacked, batch[0].stack, batch[0].padding_value)
+ elif isinstance(batch[0], Sequence):
+ transposed = zip(*batch)
+ return [collate(samples, samples_per_gpu) for samples in transposed]
+ elif isinstance(batch[0], Mapping):
+ return {
+ key: collate([d[key] for d in batch], samples_per_gpu)
+ for key in batch[0]
+ }
+ else:
+ return default_collate(batch)
diff --git a/src/custom_mmpkg/custom_mmcv/parallel/data_container.py b/src/custom_mmpkg/custom_mmcv/parallel/data_container.py
new file mode 100644
index 0000000000000000000000000000000000000000..cedb0d32a51a1f575a622b38de2cee3ab4757821
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/parallel/data_container.py
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+
+import torch
+
+
+def assert_tensor_type(func):
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ if not isinstance(args[0].data, torch.Tensor):
+ raise AttributeError(
+ f'{args[0].__class__.__name__} has no attribute '
+ f'{func.__name__} for type {args[0].datatype}')
+ return func(*args, **kwargs)
+
+ return wrapper
+
+
+class DataContainer:
+ """A container for any type of objects.
+
+ Typically tensors will be stacked in the collate function and sliced along
+ some dimension in the scatter function. This behavior has some limitations.
+ 1. All tensors have to be the same size.
+ 2. Types are limited (numpy array or Tensor).
+
+ We design `DataContainer` and `MMDataParallel` to overcome these
+ limitations. The behavior can be either of the following.
+
+ - copy to GPU, pad all tensors to the same size and stack them
+ - copy to GPU without stacking
+ - leave the objects as is and pass it to the model
+ - pad_dims specifies the number of last few dimensions to do padding
+ """
+
+ def __init__(self,
+ data,
+ stack=False,
+ padding_value=0,
+ cpu_only=False,
+ pad_dims=2):
+ self._data = data
+ self._cpu_only = cpu_only
+ self._stack = stack
+ self._padding_value = padding_value
+ assert pad_dims in [None, 1, 2, 3]
+ self._pad_dims = pad_dims
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}({repr(self.data)})'
+
+ def __len__(self):
+ return len(self._data)
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def datatype(self):
+ if isinstance(self.data, torch.Tensor):
+ return self.data.type()
+ else:
+ return type(self.data)
+
+ @property
+ def cpu_only(self):
+ return self._cpu_only
+
+ @property
+ def stack(self):
+ return self._stack
+
+ @property
+ def padding_value(self):
+ return self._padding_value
+
+ @property
+ def pad_dims(self):
+ return self._pad_dims
+
+ @assert_tensor_type
+ def size(self, *args, **kwargs):
+ return self.data.size(*args, **kwargs)
+
+ @assert_tensor_type
+ def dim(self):
+ return self.data.dim()
diff --git a/src/custom_mmpkg/custom_mmcv/parallel/data_parallel.py b/src/custom_mmpkg/custom_mmcv/parallel/data_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..79b5f69b654cf647dc7ae9174223781ab5c607d2
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/parallel/data_parallel.py
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from itertools import chain
+
+from torch.nn.parallel import DataParallel
+
+from .scatter_gather import scatter_kwargs
+
+
+class MMDataParallel(DataParallel):
+ """The DataParallel module that supports DataContainer.
+
+ MMDataParallel has two main differences with PyTorch DataParallel:
+
+ - It supports a custom type :class:`DataContainer` which allows more
+ flexible control of input data during both GPU and CPU inference.
+ - It implement two more APIs ``train_step()`` and ``val_step()``.
+
+ Args:
+ module (:class:`nn.Module`): Module to be encapsulated.
+ device_ids (list[int]): Device IDS of modules to be scattered to.
+ Defaults to None when GPU is not available.
+ output_device (str | int): Device ID for output. Defaults to None.
+ dim (int): Dimension used to scatter the data. Defaults to 0.
+ """
+
+ def __init__(self, *args, dim=0, **kwargs):
+ super(MMDataParallel, self).__init__(*args, dim=dim, **kwargs)
+ self.dim = dim
+
+ def forward(self, *inputs, **kwargs):
+ """Override the original forward function.
+
+ The main difference lies in the CPU inference where the data in
+ :class:`DataContainers` will still be gathered.
+ """
+ if not self.device_ids:
+ # We add the following line thus the module could gather and
+ # convert data containers as those in GPU inference
+ inputs, kwargs = self.scatter(inputs, kwargs, [-1])
+ return self.module(*inputs[0], **kwargs[0])
+ else:
+ return super().forward(*inputs, **kwargs)
+
+ def scatter(self, inputs, kwargs, device_ids):
+ return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
+
+ def train_step(self, *inputs, **kwargs):
+ if not self.device_ids:
+ # We add the following line thus the module could gather and
+ # convert data containers as those in GPU inference
+ inputs, kwargs = self.scatter(inputs, kwargs, [-1])
+ return self.module.train_step(*inputs[0], **kwargs[0])
+
+ assert len(self.device_ids) == 1, \
+ ('MMDataParallel only supports single GPU training, if you need to'
+ ' train with multiple GPUs, please use MMDistributedDataParallel'
+ 'instead.')
+
+ for t in chain(self.module.parameters(), self.module.buffers()):
+ if t.device != self.src_device_obj:
+ raise RuntimeError(
+ 'module must have its parameters and buffers '
+ f'on device {self.src_device_obj} (device_ids[0]) but '
+ f'found one of them on device: {t.device}')
+
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+ return self.module.train_step(*inputs[0], **kwargs[0])
+
+ def val_step(self, *inputs, **kwargs):
+ if not self.device_ids:
+ # We add the following line thus the module could gather and
+ # convert data containers as those in GPU inference
+ inputs, kwargs = self.scatter(inputs, kwargs, [-1])
+ return self.module.val_step(*inputs[0], **kwargs[0])
+
+ assert len(self.device_ids) == 1, \
+ ('MMDataParallel only supports single GPU training, if you need to'
+ ' train with multiple GPUs, please use MMDistributedDataParallel'
+ ' instead.')
+
+ for t in chain(self.module.parameters(), self.module.buffers()):
+ if t.device != self.src_device_obj:
+ raise RuntimeError(
+ 'module must have its parameters and buffers '
+ f'on device {self.src_device_obj} (device_ids[0]) but '
+ f'found one of them on device: {t.device}')
+
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+ return self.module.val_step(*inputs[0], **kwargs[0])
diff --git a/src/custom_mmpkg/custom_mmcv/parallel/distributed.py b/src/custom_mmpkg/custom_mmcv/parallel/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa1bae90f8d4078f7c52bfc565f8349f1e5c8db0
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/parallel/distributed.py
@@ -0,0 +1,112 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.nn.parallel.distributed import (DistributedDataParallel,
+ _find_tensors)
+
+from custom_mmpkg.custom_mmcv import print_log
+from custom_mmpkg.custom_mmcv.utils import TORCH_VERSION, digit_version
+from .scatter_gather import scatter_kwargs
+
+
+class MMDistributedDataParallel(DistributedDataParallel):
+ """The DDP module that supports DataContainer.
+
+ MMDDP has two main differences with PyTorch DDP:
+
+ - It supports a custom type :class:`DataContainer` which allows more
+ flexible control of input data.
+ - It implement two APIs ``train_step()`` and ``val_step()``.
+ """
+
+ def to_kwargs(self, inputs, kwargs, device_id):
+ # Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8
+ # to move all tensors to device_id
+ return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)
+
+ def scatter(self, inputs, kwargs, device_ids):
+ return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
+
+ def train_step(self, *inputs, **kwargs):
+ """train_step() API for module wrapped by DistributedDataParallel.
+
+ This method is basically the same as
+ ``DistributedDataParallel.forward()``, while replacing
+ ``self.module.forward()`` with ``self.module.train_step()``.
+ It is compatible with PyTorch 1.1 - 1.5.
+ """
+
+ # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
+ # end of backward to the beginning of forward.
+ if ('parrots' not in TORCH_VERSION
+ and digit_version(TORCH_VERSION) >= digit_version('1.7')
+ and self.reducer._rebuild_buckets()):
+ print_log(
+ 'Reducer buckets have been rebuilt in this iteration.',
+ logger='mmcv')
+
+ if getattr(self, 'require_forward_param_sync', True):
+ self._sync_params()
+ if self.device_ids:
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+ if len(self.device_ids) == 1:
+ output = self.module.train_step(*inputs[0], **kwargs[0])
+ else:
+ outputs = self.parallel_apply(
+ self._module_copies[:len(inputs)], inputs, kwargs)
+ output = self.gather(outputs, self.output_device)
+ else:
+ output = self.module.train_step(*inputs, **kwargs)
+
+ if torch.is_grad_enabled() and getattr(
+ self, 'require_backward_grad_sync', True):
+ if self.find_unused_parameters:
+ self.reducer.prepare_for_backward(list(_find_tensors(output)))
+ else:
+ self.reducer.prepare_for_backward([])
+ else:
+ if ('parrots' not in TORCH_VERSION
+ and digit_version(TORCH_VERSION) > digit_version('1.2')):
+ self.require_forward_param_sync = False
+ return output
+
+ def val_step(self, *inputs, **kwargs):
+ """val_step() API for module wrapped by DistributedDataParallel.
+
+ This method is basically the same as
+ ``DistributedDataParallel.forward()``, while replacing
+ ``self.module.forward()`` with ``self.module.val_step()``.
+ It is compatible with PyTorch 1.1 - 1.5.
+ """
+ # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
+ # end of backward to the beginning of forward.
+ if ('parrots' not in TORCH_VERSION
+ and digit_version(TORCH_VERSION) >= digit_version('1.7')
+ and self.reducer._rebuild_buckets()):
+ print_log(
+ 'Reducer buckets have been rebuilt in this iteration.',
+ logger='mmcv')
+
+ if getattr(self, 'require_forward_param_sync', True):
+ self._sync_params()
+ if self.device_ids:
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+ if len(self.device_ids) == 1:
+ output = self.module.val_step(*inputs[0], **kwargs[0])
+ else:
+ outputs = self.parallel_apply(
+ self._module_copies[:len(inputs)], inputs, kwargs)
+ output = self.gather(outputs, self.output_device)
+ else:
+ output = self.module.val_step(*inputs, **kwargs)
+
+ if torch.is_grad_enabled() and getattr(
+ self, 'require_backward_grad_sync', True):
+ if self.find_unused_parameters:
+ self.reducer.prepare_for_backward(list(_find_tensors(output)))
+ else:
+ self.reducer.prepare_for_backward([])
+ else:
+ if ('parrots' not in TORCH_VERSION
+ and digit_version(TORCH_VERSION) > digit_version('1.2')):
+ self.require_forward_param_sync = False
+ return output
diff --git a/src/custom_mmpkg/custom_mmcv/parallel/distributed_deprecated.py b/src/custom_mmpkg/custom_mmcv/parallel/distributed_deprecated.py
new file mode 100644
index 0000000000000000000000000000000000000000..d31f7be0eb5b7f92c0d2fca6faca69152472ac27
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/parallel/distributed_deprecated.py
@@ -0,0 +1,70 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch._utils import (_flatten_dense_tensors, _take_tensors,
+ _unflatten_dense_tensors)
+
+from custom_mmpkg.custom_mmcv.utils import TORCH_VERSION, digit_version
+from .registry import MODULE_WRAPPERS
+from .scatter_gather import scatter_kwargs
+
+
+@MODULE_WRAPPERS.register_module()
+class MMDistributedDataParallel(nn.Module):
+
+ def __init__(self,
+ module,
+ dim=0,
+ broadcast_buffers=True,
+ bucket_cap_mb=25):
+ super(MMDistributedDataParallel, self).__init__()
+ self.module = module
+ self.dim = dim
+ self.broadcast_buffers = broadcast_buffers
+
+ self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024
+ self._sync_params()
+
+ def _dist_broadcast_coalesced(self, tensors, buffer_size):
+ for tensors in _take_tensors(tensors, buffer_size):
+ flat_tensors = _flatten_dense_tensors(tensors)
+ dist.broadcast(flat_tensors, 0)
+ for tensor, synced in zip(
+ tensors, _unflatten_dense_tensors(flat_tensors, tensors)):
+ tensor.copy_(synced)
+
+ def _sync_params(self):
+ module_states = list(self.module.state_dict().values())
+ if len(module_states) > 0:
+ self._dist_broadcast_coalesced(module_states,
+ self.broadcast_bucket_size)
+ if self.broadcast_buffers:
+ if (TORCH_VERSION != 'parrots'
+ and digit_version(TORCH_VERSION) < digit_version('1.0')):
+ buffers = [b.data for b in self.module._all_buffers()]
+ else:
+ buffers = [b.data for b in self.module.buffers()]
+ if len(buffers) > 0:
+ self._dist_broadcast_coalesced(buffers,
+ self.broadcast_bucket_size)
+
+ def scatter(self, inputs, kwargs, device_ids):
+ return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
+
+ def forward(self, *inputs, **kwargs):
+ inputs, kwargs = self.scatter(inputs, kwargs,
+ [torch.cuda.current_device()])
+ return self.module(*inputs[0], **kwargs[0])
+
+ def train_step(self, *inputs, **kwargs):
+ inputs, kwargs = self.scatter(inputs, kwargs,
+ [torch.cuda.current_device()])
+ output = self.module.train_step(*inputs[0], **kwargs[0])
+ return output
+
+ def val_step(self, *inputs, **kwargs):
+ inputs, kwargs = self.scatter(inputs, kwargs,
+ [torch.cuda.current_device()])
+ output = self.module.val_step(*inputs[0], **kwargs[0])
+ return output
diff --git a/src/custom_mmpkg/custom_mmcv/parallel/registry.py b/src/custom_mmpkg/custom_mmcv/parallel/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a0e9f6639628c444e4682d639eabbef76114d01
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/parallel/registry.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+
+from custom_mmpkg.custom_mmcv.utils import Registry
+
+MODULE_WRAPPERS = Registry('module wrapper')
+MODULE_WRAPPERS.register_module(module=DataParallel)
+MODULE_WRAPPERS.register_module(module=DistributedDataParallel)
diff --git a/src/custom_mmpkg/custom_mmcv/parallel/scatter_gather.py b/src/custom_mmpkg/custom_mmcv/parallel/scatter_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..900ff88566f8f14830590459dc4fd16d4b382e47
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/parallel/scatter_gather.py
@@ -0,0 +1,59 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.nn.parallel._functions import Scatter as OrigScatter
+
+from ._functions import Scatter
+from .data_container import DataContainer
+
+
+def scatter(inputs, target_gpus, dim=0):
+ """Scatter inputs to target gpus.
+
+ The only difference from original :func:`scatter` is to add support for
+ :type:`~mmcv.parallel.DataContainer`.
+ """
+
+ def scatter_map(obj):
+ if isinstance(obj, torch.Tensor):
+ if target_gpus != [-1]:
+ return OrigScatter.apply(target_gpus, None, dim, obj)
+ else:
+ # for CPU inference we use self-implemented scatter
+ return Scatter.forward(target_gpus, obj)
+ if isinstance(obj, DataContainer):
+ if obj.cpu_only:
+ return obj.data
+ else:
+ return Scatter.forward(target_gpus, obj.data)
+ if isinstance(obj, tuple) and len(obj) > 0:
+ return list(zip(*map(scatter_map, obj)))
+ if isinstance(obj, list) and len(obj) > 0:
+ out = list(map(list, zip(*map(scatter_map, obj))))
+ return out
+ if isinstance(obj, dict) and len(obj) > 0:
+ out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
+ return out
+ return [obj for targets in target_gpus]
+
+ # After scatter_map is called, a scatter_map cell will exist. This cell
+ # has a reference to the actual function scatter_map, which has references
+ # to a closure that has a reference to the scatter_map cell (because the
+ # fn is recursive). To avoid this reference cycle, we set the function to
+ # None, clearing the cell
+ try:
+ return scatter_map(inputs)
+ finally:
+ scatter_map = None
+
+
+def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
+ """Scatter with support for kwargs dictionary."""
+ inputs = scatter(inputs, target_gpus, dim) if inputs else []
+ kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
+ if len(inputs) < len(kwargs):
+ inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
+ elif len(kwargs) < len(inputs):
+ kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
+ inputs = tuple(inputs)
+ kwargs = tuple(kwargs)
+ return inputs, kwargs
diff --git a/src/custom_mmpkg/custom_mmcv/parallel/utils.py b/src/custom_mmpkg/custom_mmcv/parallel/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f5712cb42c38a2e8563bf563efb6681383cab9b
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/parallel/utils.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .registry import MODULE_WRAPPERS
+
+
+def is_module_wrapper(module):
+ """Check if a module is a module wrapper.
+
+ The following 3 modules in MMCV (and their subclasses) are regarded as
+ module wrappers: DataParallel, DistributedDataParallel,
+ MMDistributedDataParallel (the deprecated version). You may add you own
+ module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS.
+
+ Args:
+ module (nn.Module): The module to be checked.
+
+ Returns:
+ bool: True if the input module is a module wrapper.
+ """
+ module_wrappers = tuple(MODULE_WRAPPERS.module_dict.values())
+ return isinstance(module, module_wrappers)
diff --git a/src/custom_mmpkg/custom_mmcv/runner/__init__.py b/src/custom_mmpkg/custom_mmcv/runner/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..52e4b48d383a84a055dcd7f6236f6e8e58eab924
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/__init__.py
@@ -0,0 +1,47 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base_module import BaseModule, ModuleList, Sequential
+from .base_runner import BaseRunner
+from .builder import RUNNERS, build_runner
+from .checkpoint import (CheckpointLoader, _load_checkpoint,
+ _load_checkpoint_with_prefix, load_checkpoint,
+ load_state_dict, save_checkpoint, weights_to_cpu)
+from .default_constructor import DefaultRunnerConstructor
+from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info,
+ init_dist, master_only)
+from .epoch_based_runner import EpochBasedRunner, Runner
+from .fp16_utils import LossScaler, auto_fp16, force_fp32, wrap_fp16_model
+from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistEvalHook,
+ DistSamplerSeedHook, DvcliveLoggerHook, EMAHook, EvalHook,
+ Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
+ GradientCumulativeOptimizerHook, Hook, IterTimerHook,
+ LoggerHook, LrUpdaterHook, MlflowLoggerHook,
+ NeptuneLoggerHook, OptimizerHook, PaviLoggerHook,
+ SyncBuffersHook, TensorboardLoggerHook, TextLoggerHook,
+ WandbLoggerHook)
+from .iter_based_runner import IterBasedRunner, IterLoader
+from .log_buffer import LogBuffer
+from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
+ DefaultOptimizerConstructor, build_optimizer,
+ build_optimizer_constructor)
+from .priority import Priority, get_priority
+from .utils import get_host_info, get_time_str, obj_from_dict, set_random_seed
+
+__all__ = [
+ 'BaseRunner', 'Runner', 'EpochBasedRunner', 'IterBasedRunner', 'LogBuffer',
+ 'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
+ 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook',
+ 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
+ 'NeptuneLoggerHook', 'WandbLoggerHook', 'MlflowLoggerHook',
+ 'DvcliveLoggerHook', '_load_checkpoint', 'load_state_dict',
+ 'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'Priority',
+ 'get_priority', 'get_host_info', 'get_time_str', 'obj_from_dict',
+ 'init_dist', 'get_dist_info', 'master_only', 'OPTIMIZER_BUILDERS',
+ 'OPTIMIZERS', 'DefaultOptimizerConstructor', 'build_optimizer',
+ 'build_optimizer_constructor', 'IterLoader', 'set_random_seed',
+ 'auto_fp16', 'force_fp32', 'wrap_fp16_model', 'Fp16OptimizerHook',
+ 'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads',
+ 'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule',
+ '_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook', 'Sequential',
+ 'ModuleList', 'GradientCumulativeOptimizerHook',
+ 'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor'
+]
diff --git a/src/custom_mmpkg/custom_mmcv/runner/base_module.py b/src/custom_mmpkg/custom_mmcv/runner/base_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..362b0ae39a9e5e92b22f52918eaecc11dfde10b3
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/base_module.py
@@ -0,0 +1,195 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import warnings
+from abc import ABCMeta
+from collections import defaultdict
+from logging import FileHandler
+
+import torch.nn as nn
+
+from custom_mmpkg.custom_mmcv.runner.dist_utils import master_only
+from custom_mmpkg.custom_mmcv.utils.logging import get_logger, logger_initialized, print_log
+
+
+class BaseModule(nn.Module, metaclass=ABCMeta):
+ """Base module for all modules in openmmlab.
+
+ ``BaseModule`` is a wrapper of ``torch.nn.Module`` with additional
+ functionality of parameter initialization. Compared with
+ ``torch.nn.Module``, ``BaseModule`` mainly adds three attributes.
+
+ - ``init_cfg``: the config to control the initialization.
+ - ``init_weights``: The function of parameter
+ initialization and recording initialization
+ information.
+ - ``_params_init_info``: Used to track the parameter
+ initialization information. This attribute only
+ exists during executing the ``init_weights``.
+
+ Args:
+ init_cfg (dict, optional): Initialization config dict.
+ """
+
+ def __init__(self, init_cfg=None):
+ """Initialize BaseModule, inherited from `torch.nn.Module`"""
+
+ # NOTE init_cfg can be defined in different levels, but init_cfg
+ # in low levels has a higher priority.
+
+ super(BaseModule, self).__init__()
+ # define default value of init_cfg instead of hard code
+ # in init_weights() function
+ self._is_init = False
+
+ self.init_cfg = copy.deepcopy(init_cfg)
+
+ # Backward compatibility in derived classes
+ # if pretrained is not None:
+ # warnings.warn('DeprecationWarning: pretrained is a deprecated \
+ # key, please consider using init_cfg')
+ # self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+
+ @property
+ def is_init(self):
+ return self._is_init
+
+ def init_weights(self):
+ """Initialize the weights."""
+
+ is_top_level_module = False
+ # check if it is top-level module
+ if not hasattr(self, '_params_init_info'):
+ # The `_params_init_info` is used to record the initialization
+ # information of the parameters
+ # the key should be the obj:`nn.Parameter` of model and the value
+ # should be a dict containing
+ # - init_info (str): The string that describes the initialization.
+ # - tmp_mean_value (FloatTensor): The mean of the parameter,
+ # which indicates whether the parameter has been modified.
+ # this attribute would be deleted after all parameters
+ # is initialized.
+ self._params_init_info = defaultdict(dict)
+ is_top_level_module = True
+
+ # Initialize the `_params_init_info`,
+ # When detecting the `tmp_mean_value` of
+ # the corresponding parameter is changed, update related
+ # initialization information
+ for name, param in self.named_parameters():
+ self._params_init_info[param][
+ 'init_info'] = f'The value is the same before and ' \
+ f'after calling `init_weights` ' \
+ f'of {self.__class__.__name__} '
+ self._params_init_info[param][
+ 'tmp_mean_value'] = param.data.mean()
+
+ # pass `params_init_info` to all submodules
+ # All submodules share the same `params_init_info`,
+ # so it will be updated when parameters are
+ # modified at any level of the model.
+ for sub_module in self.modules():
+ sub_module._params_init_info = self._params_init_info
+
+ # Get the initialized logger, if not exist,
+ # create a logger named `mmcv`
+ logger_names = list(logger_initialized.keys())
+ logger_name = logger_names[0] if logger_names else 'mmcv'
+
+ from ..cnn import initialize
+ from ..cnn.utils.weight_init import update_init_info
+ module_name = self.__class__.__name__
+ if not self._is_init:
+ if self.init_cfg:
+ print_log(
+ f'initialize {module_name} with init_cfg {self.init_cfg}',
+ logger=logger_name)
+ initialize(self, self.init_cfg)
+ if isinstance(self.init_cfg, dict):
+ # prevent the parameters of
+ # the pre-trained model
+ # from being overwritten by
+ # the `init_weights`
+ if self.init_cfg['type'] == 'Pretrained':
+ return
+
+ for m in self.children():
+ if hasattr(m, 'init_weights'):
+ m.init_weights()
+ # users may overload the `init_weights`
+ update_init_info(
+ m,
+ init_info=f'Initialized by '
+ f'user-defined `init_weights`'
+ f' in {m.__class__.__name__} ')
+
+ self._is_init = True
+ else:
+ warnings.warn(f'init_weights of {self.__class__.__name__} has '
+ f'been called more than once.')
+
+ if is_top_level_module:
+ self._dump_init_info(logger_name)
+
+ for sub_module in self.modules():
+ del sub_module._params_init_info
+
+ @master_only
+ def _dump_init_info(self, logger_name):
+ """Dump the initialization information to a file named
+ `initialization.log.json` in workdir.
+
+ Args:
+ logger_name (str): The name of logger.
+ """
+
+ logger = get_logger(logger_name)
+
+ with_file_handler = False
+ # dump the information to the logger file if there is a `FileHandler`
+ for handler in logger.handlers:
+ if isinstance(handler, FileHandler):
+ handler.stream.write(
+ 'Name of parameter - Initialization information\n')
+ for name, param in self.named_parameters():
+ handler.stream.write(
+ f'\n{name} - {param.shape}: '
+ f"\n{self._params_init_info[param]['init_info']} \n")
+ handler.stream.flush()
+ with_file_handler = True
+ if not with_file_handler:
+ for name, param in self.named_parameters():
+ print_log(
+ f'\n{name} - {param.shape}: '
+ f"\n{self._params_init_info[param]['init_info']} \n ",
+ logger=logger_name)
+
+ def __repr__(self):
+ s = super().__repr__()
+ if self.init_cfg:
+ s += f'\ninit_cfg={self.init_cfg}'
+ return s
+
+
+class Sequential(BaseModule, nn.Sequential):
+ """Sequential module in openmmlab.
+
+ Args:
+ init_cfg (dict, optional): Initialization config dict.
+ """
+
+ def __init__(self, *args, init_cfg=None):
+ BaseModule.__init__(self, init_cfg)
+ nn.Sequential.__init__(self, *args)
+
+
+class ModuleList(BaseModule, nn.ModuleList):
+ """ModuleList in openmmlab.
+
+ Args:
+ modules (iterable, optional): an iterable of modules to add.
+ init_cfg (dict, optional): Initialization config dict.
+ """
+
+ def __init__(self, modules=None, init_cfg=None):
+ BaseModule.__init__(self, init_cfg)
+ nn.ModuleList.__init__(self, modules)
diff --git a/src/custom_mmpkg/custom_mmcv/runner/base_runner.py b/src/custom_mmpkg/custom_mmcv/runner/base_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f15d71940ae558c10fcd4372d0c87f1efde93a9
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/base_runner.py
@@ -0,0 +1,542 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import logging
+import os.path as osp
+import warnings
+from abc import ABCMeta, abstractmethod
+
+import torch
+from torch.optim import Optimizer
+
+import custom_mmpkg.custom_mmcv as mmcv
+from ..parallel import is_module_wrapper
+from .checkpoint import load_checkpoint
+from .dist_utils import get_dist_info
+from .hooks import HOOKS, Hook
+from .log_buffer import LogBuffer
+from .priority import Priority, get_priority
+from .utils import get_time_str
+
+
+class BaseRunner(metaclass=ABCMeta):
+ """The base class of Runner, a training helper for PyTorch.
+
+ All subclasses should implement the following APIs:
+
+ - ``run()``
+ - ``train()``
+ - ``val()``
+ - ``save_checkpoint()``
+
+ Args:
+ model (:obj:`torch.nn.Module`): The model to be run.
+ batch_processor (callable): A callable method that process a data
+ batch. The interface of this method should be
+ `batch_processor(model, data, train_mode) -> dict`
+ optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
+ optimizer (in most cases) or a dict of optimizers (in models that
+ requires more than one optimizer, e.g., GAN).
+ work_dir (str, optional): The working directory to save checkpoints
+ and logs. Defaults to None.
+ logger (:obj:`logging.Logger`): Logger used during training.
+ Defaults to None. (The default value is just for backward
+ compatibility)
+ meta (dict | None): A dict records some import information such as
+ environment info and seed, which will be logged in logger hook.
+ Defaults to None.
+ max_epochs (int, optional): Total training epochs.
+ max_iters (int, optional): Total training iterations.
+ """
+
+ def __init__(self,
+ model,
+ batch_processor=None,
+ optimizer=None,
+ work_dir=None,
+ logger=None,
+ meta=None,
+ max_iters=None,
+ max_epochs=None):
+ if batch_processor is not None:
+ if not callable(batch_processor):
+ raise TypeError('batch_processor must be callable, '
+ f'but got {type(batch_processor)}')
+ warnings.warn('batch_processor is deprecated, please implement '
+ 'train_step() and val_step() in the model instead.')
+ # raise an error is `batch_processor` is not None and
+ # `model.train_step()` exists.
+ if is_module_wrapper(model):
+ _model = model.module
+ else:
+ _model = model
+ if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
+ raise RuntimeError(
+ 'batch_processor and model.train_step()/model.val_step() '
+ 'cannot be both available.')
+ else:
+ assert hasattr(model, 'train_step')
+
+ # check the type of `optimizer`
+ if isinstance(optimizer, dict):
+ for name, optim in optimizer.items():
+ if not isinstance(optim, Optimizer):
+ raise TypeError(
+ f'optimizer must be a dict of torch.optim.Optimizers, '
+ f'but optimizer["{name}"] is a {type(optim)}')
+ elif not isinstance(optimizer, Optimizer) and optimizer is not None:
+ raise TypeError(
+ f'optimizer must be a torch.optim.Optimizer object '
+ f'or dict or None, but got {type(optimizer)}')
+
+ # check the type of `logger`
+ if not isinstance(logger, logging.Logger):
+ raise TypeError(f'logger must be a logging.Logger object, '
+ f'but got {type(logger)}')
+
+ # check the type of `meta`
+ if meta is not None and not isinstance(meta, dict):
+ raise TypeError(
+ f'meta must be a dict or None, but got {type(meta)}')
+
+ self.model = model
+ self.batch_processor = batch_processor
+ self.optimizer = optimizer
+ self.logger = logger
+ self.meta = meta
+ # create work_dir
+ if mmcv.is_str(work_dir):
+ self.work_dir = osp.abspath(work_dir)
+ mmcv.mkdir_or_exist(self.work_dir)
+ elif work_dir is None:
+ self.work_dir = None
+ else:
+ raise TypeError('"work_dir" must be a str or None')
+
+ # get model name from the model class
+ if hasattr(self.model, 'module'):
+ self._model_name = self.model.module.__class__.__name__
+ else:
+ self._model_name = self.model.__class__.__name__
+
+ self._rank, self._world_size = get_dist_info()
+ self.timestamp = get_time_str()
+ self.mode = None
+ self._hooks = []
+ self._epoch = 0
+ self._iter = 0
+ self._inner_iter = 0
+
+ if max_epochs is not None and max_iters is not None:
+ raise ValueError(
+ 'Only one of `max_epochs` or `max_iters` can be set.')
+
+ self._max_epochs = max_epochs
+ self._max_iters = max_iters
+ # TODO: Redesign LogBuffer, it is not flexible and elegant enough
+ self.log_buffer = LogBuffer()
+
+ @property
+ def model_name(self):
+ """str: Name of the model, usually the module class name."""
+ return self._model_name
+
+ @property
+ def rank(self):
+ """int: Rank of current process. (distributed training)"""
+ return self._rank
+
+ @property
+ def world_size(self):
+ """int: Number of processes participating in the job.
+ (distributed training)"""
+ return self._world_size
+
+ @property
+ def hooks(self):
+ """list[:obj:`Hook`]: A list of registered hooks."""
+ return self._hooks
+
+ @property
+ def epoch(self):
+ """int: Current epoch."""
+ return self._epoch
+
+ @property
+ def iter(self):
+ """int: Current iteration."""
+ return self._iter
+
+ @property
+ def inner_iter(self):
+ """int: Iteration in an epoch."""
+ return self._inner_iter
+
+ @property
+ def max_epochs(self):
+ """int: Maximum training epochs."""
+ return self._max_epochs
+
+ @property
+ def max_iters(self):
+ """int: Maximum training iterations."""
+ return self._max_iters
+
+ @abstractmethod
+ def train(self):
+ pass
+
+ @abstractmethod
+ def val(self):
+ pass
+
+ @abstractmethod
+ def run(self, data_loaders, workflow, **kwargs):
+ pass
+
+ @abstractmethod
+ def save_checkpoint(self,
+ out_dir,
+ filename_tmpl,
+ save_optimizer=True,
+ meta=None,
+ create_symlink=True):
+ pass
+
+ def current_lr(self):
+ """Get current learning rates.
+
+ Returns:
+ list[float] | dict[str, list[float]]: Current learning rates of all
+ param groups. If the runner has a dict of optimizers, this
+ method will return a dict.
+ """
+ if isinstance(self.optimizer, torch.optim.Optimizer):
+ lr = [group['lr'] for group in self.optimizer.param_groups]
+ elif isinstance(self.optimizer, dict):
+ lr = dict()
+ for name, optim in self.optimizer.items():
+ lr[name] = [group['lr'] for group in optim.param_groups]
+ else:
+ raise RuntimeError(
+ 'lr is not applicable because optimizer does not exist.')
+ return lr
+
+ def current_momentum(self):
+ """Get current momentums.
+
+ Returns:
+ list[float] | dict[str, list[float]]: Current momentums of all
+ param groups. If the runner has a dict of optimizers, this
+ method will return a dict.
+ """
+
+ def _get_momentum(optimizer):
+ momentums = []
+ for group in optimizer.param_groups:
+ if 'momentum' in group.keys():
+ momentums.append(group['momentum'])
+ elif 'betas' in group.keys():
+ momentums.append(group['betas'][0])
+ else:
+ momentums.append(0)
+ return momentums
+
+ if self.optimizer is None:
+ raise RuntimeError(
+ 'momentum is not applicable because optimizer does not exist.')
+ elif isinstance(self.optimizer, torch.optim.Optimizer):
+ momentums = _get_momentum(self.optimizer)
+ elif isinstance(self.optimizer, dict):
+ momentums = dict()
+ for name, optim in self.optimizer.items():
+ momentums[name] = _get_momentum(optim)
+ return momentums
+
+ def register_hook(self, hook, priority='NORMAL'):
+ """Register a hook into the hook list.
+
+ The hook will be inserted into a priority queue, with the specified
+ priority (See :class:`Priority` for details of priorities).
+ For hooks with the same priority, they will be triggered in the same
+ order as they are registered.
+
+ Args:
+ hook (:obj:`Hook`): The hook to be registered.
+ priority (int or str or :obj:`Priority`): Hook priority.
+ Lower value means higher priority.
+ """
+ assert isinstance(hook, Hook)
+ if hasattr(hook, 'priority'):
+ raise ValueError('"priority" is a reserved attribute for hooks')
+ priority = get_priority(priority)
+ hook.priority = priority
+ # insert the hook to a sorted list
+ inserted = False
+ for i in range(len(self._hooks) - 1, -1, -1):
+ if priority >= self._hooks[i].priority:
+ self._hooks.insert(i + 1, hook)
+ inserted = True
+ break
+ if not inserted:
+ self._hooks.insert(0, hook)
+
+ def register_hook_from_cfg(self, hook_cfg):
+ """Register a hook from its cfg.
+
+ Args:
+ hook_cfg (dict): Hook config. It should have at least keys 'type'
+ and 'priority' indicating its type and priority.
+
+ Notes:
+ The specific hook class to register should not use 'type' and
+ 'priority' arguments during initialization.
+ """
+ hook_cfg = hook_cfg.copy()
+ priority = hook_cfg.pop('priority', 'NORMAL')
+ hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
+ self.register_hook(hook, priority=priority)
+
+ def call_hook(self, fn_name):
+ """Call all hooks.
+
+ Args:
+ fn_name (str): The function name in each hook to be called, such as
+ "before_train_epoch".
+ """
+ for hook in self._hooks:
+ getattr(hook, fn_name)(self)
+
+ def get_hook_info(self):
+ # Get hooks info in each stage
+ stage_hook_map = {stage: [] for stage in Hook.stages}
+ for hook in self.hooks:
+ try:
+ priority = Priority(hook.priority).name
+ except ValueError:
+ priority = hook.priority
+ classname = hook.__class__.__name__
+ hook_info = f'({priority:<12}) {classname:<35}'
+ for trigger_stage in hook.get_triggered_stages():
+ stage_hook_map[trigger_stage].append(hook_info)
+
+ stage_hook_infos = []
+ for stage in Hook.stages:
+ hook_infos = stage_hook_map[stage]
+ if len(hook_infos) > 0:
+ info = f'{stage}:\n'
+ info += '\n'.join(hook_infos)
+ info += '\n -------------------- '
+ stage_hook_infos.append(info)
+ return '\n'.join(stage_hook_infos)
+
+ def load_checkpoint(self,
+ filename,
+ map_location='cpu',
+ strict=False,
+ revise_keys=[(r'^module.', '')]):
+ return load_checkpoint(
+ self.model,
+ filename,
+ map_location,
+ strict,
+ self.logger,
+ revise_keys=revise_keys)
+
+ def resume(self,
+ checkpoint,
+ resume_optimizer=True,
+ map_location='default'):
+ if map_location == 'default':
+ if torch.cuda.is_available():
+ device_id = torch.cuda.current_device()
+ checkpoint = self.load_checkpoint(
+ checkpoint,
+ map_location=lambda storage, loc: storage.cuda(device_id))
+ else:
+ checkpoint = self.load_checkpoint(checkpoint)
+ else:
+ checkpoint = self.load_checkpoint(
+ checkpoint, map_location=map_location)
+
+ self._epoch = checkpoint['meta']['epoch']
+ self._iter = checkpoint['meta']['iter']
+ if self.meta is None:
+ self.meta = {}
+ self.meta.setdefault('hook_msgs', {})
+ # load `last_ckpt`, `best_score`, `best_ckpt`, etc. for hook messages
+ self.meta['hook_msgs'].update(checkpoint['meta'].get('hook_msgs', {}))
+
+ # Re-calculate the number of iterations when resuming
+ # models with different number of GPUs
+ if 'config' in checkpoint['meta']:
+ config = mmcv.Config.fromstring(
+ checkpoint['meta']['config'], file_format='.py')
+ previous_gpu_ids = config.get('gpu_ids', None)
+ if previous_gpu_ids and len(previous_gpu_ids) > 0 and len(
+ previous_gpu_ids) != self.world_size:
+ self._iter = int(self._iter * len(previous_gpu_ids) /
+ self.world_size)
+ self.logger.info('the iteration number is changed due to '
+ 'change of GPU number')
+
+ # resume meta information meta
+ self.meta = checkpoint['meta']
+
+ if 'optimizer' in checkpoint and resume_optimizer:
+ if isinstance(self.optimizer, Optimizer):
+ self.optimizer.load_state_dict(checkpoint['optimizer'])
+ elif isinstance(self.optimizer, dict):
+ for k in self.optimizer.keys():
+ self.optimizer[k].load_state_dict(
+ checkpoint['optimizer'][k])
+ else:
+ raise TypeError(
+ 'Optimizer should be dict or torch.optim.Optimizer '
+ f'but got {type(self.optimizer)}')
+
+ self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
+
+ def register_lr_hook(self, lr_config):
+ if lr_config is None:
+ return
+ elif isinstance(lr_config, dict):
+ assert 'policy' in lr_config
+ policy_type = lr_config.pop('policy')
+ # If the type of policy is all in lower case, e.g., 'cyclic',
+ # then its first letter will be capitalized, e.g., to be 'Cyclic'.
+ # This is for the convenient usage of Lr updater.
+ # Since this is not applicable for `
+ # CosineAnnealingLrUpdater`,
+ # the string will not be changed if it contains capital letters.
+ if policy_type == policy_type.lower():
+ policy_type = policy_type.title()
+ hook_type = policy_type + 'LrUpdaterHook'
+ lr_config['type'] = hook_type
+ hook = mmcv.build_from_cfg(lr_config, HOOKS)
+ else:
+ hook = lr_config
+ self.register_hook(hook, priority='VERY_HIGH')
+
+ def register_momentum_hook(self, momentum_config):
+ if momentum_config is None:
+ return
+ if isinstance(momentum_config, dict):
+ assert 'policy' in momentum_config
+ policy_type = momentum_config.pop('policy')
+ # If the type of policy is all in lower case, e.g., 'cyclic',
+ # then its first letter will be capitalized, e.g., to be 'Cyclic'.
+ # This is for the convenient usage of momentum updater.
+ # Since this is not applicable for
+ # `CosineAnnealingMomentumUpdater`,
+ # the string will not be changed if it contains capital letters.
+ if policy_type == policy_type.lower():
+ policy_type = policy_type.title()
+ hook_type = policy_type + 'MomentumUpdaterHook'
+ momentum_config['type'] = hook_type
+ hook = mmcv.build_from_cfg(momentum_config, HOOKS)
+ else:
+ hook = momentum_config
+ self.register_hook(hook, priority='HIGH')
+
+ def register_optimizer_hook(self, optimizer_config):
+ if optimizer_config is None:
+ return
+ if isinstance(optimizer_config, dict):
+ optimizer_config.setdefault('type', 'OptimizerHook')
+ hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
+ else:
+ hook = optimizer_config
+ self.register_hook(hook, priority='ABOVE_NORMAL')
+
+ def register_checkpoint_hook(self, checkpoint_config):
+ if checkpoint_config is None:
+ return
+ if isinstance(checkpoint_config, dict):
+ checkpoint_config.setdefault('type', 'CheckpointHook')
+ hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
+ else:
+ hook = checkpoint_config
+ self.register_hook(hook, priority='NORMAL')
+
+ def register_logger_hooks(self, log_config):
+ if log_config is None:
+ return
+ log_interval = log_config['interval']
+ for info in log_config['hooks']:
+ logger_hook = mmcv.build_from_cfg(
+ info, HOOKS, default_args=dict(interval=log_interval))
+ self.register_hook(logger_hook, priority='VERY_LOW')
+
+ def register_timer_hook(self, timer_config):
+ if timer_config is None:
+ return
+ if isinstance(timer_config, dict):
+ timer_config_ = copy.deepcopy(timer_config)
+ hook = mmcv.build_from_cfg(timer_config_, HOOKS)
+ else:
+ hook = timer_config
+ self.register_hook(hook, priority='LOW')
+
+ def register_custom_hooks(self, custom_config):
+ if custom_config is None:
+ return
+
+ if not isinstance(custom_config, list):
+ custom_config = [custom_config]
+
+ for item in custom_config:
+ if isinstance(item, dict):
+ self.register_hook_from_cfg(item)
+ else:
+ self.register_hook(item, priority='NORMAL')
+
+ def register_profiler_hook(self, profiler_config):
+ if profiler_config is None:
+ return
+ if isinstance(profiler_config, dict):
+ profiler_config.setdefault('type', 'ProfilerHook')
+ hook = mmcv.build_from_cfg(profiler_config, HOOKS)
+ else:
+ hook = profiler_config
+ self.register_hook(hook)
+
+ def register_training_hooks(self,
+ lr_config,
+ optimizer_config=None,
+ checkpoint_config=None,
+ log_config=None,
+ momentum_config=None,
+ timer_config=dict(type='IterTimerHook'),
+ custom_hooks_config=None):
+ """Register default and custom hooks for training.
+
+ Default and custom hooks include:
+
+ +----------------------+-------------------------+
+ | Hooks | Priority |
+ +======================+=========================+
+ | LrUpdaterHook | VERY_HIGH (10) |
+ +----------------------+-------------------------+
+ | MomentumUpdaterHook | HIGH (30) |
+ +----------------------+-------------------------+
+ | OptimizerStepperHook | ABOVE_NORMAL (40) |
+ +----------------------+-------------------------+
+ | CheckpointSaverHook | NORMAL (50) |
+ +----------------------+-------------------------+
+ | IterTimerHook | LOW (70) |
+ +----------------------+-------------------------+
+ | LoggerHook(s) | VERY_LOW (90) |
+ +----------------------+-------------------------+
+ | CustomHook(s) | defaults to NORMAL (50) |
+ +----------------------+-------------------------+
+
+ If custom hooks have same priority with default hooks, custom hooks
+ will be triggered after default hooks.
+ """
+ self.register_lr_hook(lr_config)
+ self.register_momentum_hook(momentum_config)
+ self.register_optimizer_hook(optimizer_config)
+ self.register_checkpoint_hook(checkpoint_config)
+ self.register_timer_hook(timer_config)
+ self.register_logger_hooks(log_config)
+ self.register_custom_hooks(custom_hooks_config)
diff --git a/src/custom_mmpkg/custom_mmcv/runner/builder.py b/src/custom_mmpkg/custom_mmcv/runner/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..77c96ba0b2f30ead9da23f293c5dc84dd3e4a74f
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/builder.py
@@ -0,0 +1,24 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+from ..utils import Registry
+
+RUNNERS = Registry('runner')
+RUNNER_BUILDERS = Registry('runner builder')
+
+
+def build_runner_constructor(cfg):
+ return RUNNER_BUILDERS.build(cfg)
+
+
+def build_runner(cfg, default_args=None):
+ runner_cfg = copy.deepcopy(cfg)
+ constructor_type = runner_cfg.pop('constructor',
+ 'DefaultRunnerConstructor')
+ runner_constructor = build_runner_constructor(
+ dict(
+ type=constructor_type,
+ runner_cfg=runner_cfg,
+ default_args=default_args))
+ runner = runner_constructor()
+ return runner
diff --git a/src/custom_mmpkg/custom_mmcv/runner/checkpoint.py b/src/custom_mmpkg/custom_mmcv/runner/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..da1481088ceb805007b3f1a7cad8bd528d5853f6
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/checkpoint.py
@@ -0,0 +1,707 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import io
+import os
+import os.path as osp
+import pkgutil
+import re
+import time
+import warnings
+from collections import OrderedDict
+from importlib import import_module
+from tempfile import TemporaryDirectory
+
+import torch
+import torchvision
+from torch.optim import Optimizer
+from torch.utils import model_zoo
+
+import custom_mmpkg.custom_mmcv as mmcv
+from ..fileio import FileClient
+from ..fileio import load as load_file
+from ..parallel import is_module_wrapper
+from ..utils import mkdir_or_exist
+from .dist_utils import get_dist_info
+
+ENV_MMCV_HOME = 'MMCV_HOME'
+ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+DEFAULT_CACHE_DIR = '~/.cache'
+
+
+def _get_mmcv_home():
+ mmcv_home = os.path.expanduser(
+ os.getenv(
+ ENV_MMCV_HOME,
+ os.path.join(
+ os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))
+
+ mkdir_or_exist(mmcv_home)
+ return mmcv_home
+
+
+def load_state_dict(module, state_dict, strict=False, logger=None):
+ """Load state_dict to a module.
+
+ This method is modified from :meth:`torch.nn.Module.load_state_dict`.
+ Default value for ``strict`` is set to ``False`` and the message for
+ param mismatch will be shown even if strict is False.
+
+ Args:
+ module (Module): Module that receives the state_dict.
+ state_dict (OrderedDict): Weights.
+ strict (bool): whether to strictly enforce that the keys
+ in :attr:`state_dict` match the keys returned by this module's
+ :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
+ logger (:obj:`logging.Logger`, optional): Logger to log the error
+ message. If not specified, print function will be used.
+ """
+ unexpected_keys = []
+ all_missing_keys = []
+ err_msg = []
+
+ metadata = getattr(state_dict, '_metadata', None)
+ state_dict = state_dict.copy()
+ if metadata is not None:
+ state_dict._metadata = metadata
+
+ # use _load_from_state_dict to enable checkpoint version control
+ def load(module, prefix=''):
+ # recursively check parallel module in case that the model has a
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
+ if is_module_wrapper(module):
+ module = module.module
+ local_metadata = {} if metadata is None else metadata.get(
+ prefix[:-1], {})
+ module._load_from_state_dict(state_dict, prefix, local_metadata, True,
+ all_missing_keys, unexpected_keys,
+ err_msg)
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + '.')
+
+ load(module)
+ load = None # break load->load reference cycle
+
+ # ignore "num_batches_tracked" of BN layers
+ missing_keys = [
+ key for key in all_missing_keys if 'num_batches_tracked' not in key
+ ]
+
+ if unexpected_keys:
+ err_msg.append('unexpected key in source '
+ f'state_dict: {", ".join(unexpected_keys)}\n')
+ if missing_keys:
+ err_msg.append(
+ f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
+
+ rank, _ = get_dist_info()
+ if len(err_msg) > 0 and rank == 0:
+ err_msg.insert(
+ 0, 'The model and loaded state dict do not match exactly\n')
+ err_msg = '\n'.join(err_msg)
+ if strict:
+ raise RuntimeError(err_msg)
+ elif logger is not None:
+ logger.warning(err_msg)
+ else:
+ print(err_msg)
+
+
+def get_torchvision_models():
+ model_urls = dict()
+ for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
+ if ispkg:
+ continue
+ _zoo = import_module(f'torchvision.models.{name}')
+ if hasattr(_zoo, 'model_urls'):
+ _urls = getattr(_zoo, 'model_urls')
+ model_urls.update(_urls)
+ return model_urls
+
+
+def get_external_models():
+ mmcv_home = _get_mmcv_home()
+ default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
+ default_urls = load_file(default_json_path)
+ assert isinstance(default_urls, dict)
+ external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
+ if osp.exists(external_json_path):
+ external_urls = load_file(external_json_path)
+ assert isinstance(external_urls, dict)
+ default_urls.update(external_urls)
+
+ return default_urls
+
+
+def get_mmcls_models():
+ mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
+ mmcls_urls = load_file(mmcls_json_path)
+
+ return mmcls_urls
+
+
+def get_deprecated_model_names():
+ deprecate_json_path = osp.join(mmcv.__path__[0],
+ 'model_zoo/deprecated.json')
+ deprecate_urls = load_file(deprecate_json_path)
+ assert isinstance(deprecate_urls, dict)
+
+ return deprecate_urls
+
+
+def _process_mmcls_checkpoint(checkpoint):
+ state_dict = checkpoint['state_dict']
+ new_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ if k.startswith('backbone.'):
+ new_state_dict[k[9:]] = v
+ new_checkpoint = dict(state_dict=new_state_dict)
+
+ return new_checkpoint
+
+
+class CheckpointLoader:
+ """A general checkpoint loader to manage all schemes."""
+
+ _schemes = {}
+
+ @classmethod
+ def _register_scheme(cls, prefixes, loader, force=False):
+ if isinstance(prefixes, str):
+ prefixes = [prefixes]
+ else:
+ assert isinstance(prefixes, (list, tuple))
+ for prefix in prefixes:
+ if (prefix not in cls._schemes) or force:
+ cls._schemes[prefix] = loader
+ else:
+ raise KeyError(
+ f'{prefix} is already registered as a loader backend, '
+ 'add "force=True" if you want to override it')
+ # sort, longer prefixes take priority
+ cls._schemes = OrderedDict(
+ sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True))
+
+ @classmethod
+ def register_scheme(cls, prefixes, loader=None, force=False):
+ """Register a loader to CheckpointLoader.
+
+ This method can be used as a normal class method or a decorator.
+
+ Args:
+ prefixes (str or list[str] or tuple[str]):
+ The prefix of the registered loader.
+ loader (function, optional): The loader function to be registered.
+ When this method is used as a decorator, loader is None.
+ Defaults to None.
+ force (bool, optional): Whether to override the loader
+ if the prefix has already been registered. Defaults to False.
+ """
+
+ if loader is not None:
+ cls._register_scheme(prefixes, loader, force=force)
+ return
+
+ def _register(loader_cls):
+ cls._register_scheme(prefixes, loader_cls, force=force)
+ return loader_cls
+
+ return _register
+
+ @classmethod
+ def _get_checkpoint_loader(cls, path):
+ """Finds a loader that supports the given path. Falls back to the local
+ loader if no other loader is found.
+
+ Args:
+ path (str): checkpoint path
+
+ Returns:
+ loader (function): checkpoint loader
+ """
+
+ for p in cls._schemes:
+ if path.startswith(p):
+ return cls._schemes[p]
+
+ @classmethod
+ def load_checkpoint(cls, filename, map_location=None, logger=None):
+ """load checkpoint through URL scheme path.
+
+ Args:
+ filename (str): checkpoint file name with given prefix
+ map_location (str, optional): Same as :func:`torch.load`.
+ Default: None
+ logger (:mod:`logging.Logger`, optional): The logger for message.
+ Default: None
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+
+ checkpoint_loader = cls._get_checkpoint_loader(filename)
+ class_name = checkpoint_loader.__name__
+ mmcv.print_log(
+ f'load checkpoint from {class_name[10:]} path: {filename}', logger)
+ return checkpoint_loader(filename, map_location)
+
+
+@CheckpointLoader.register_scheme(prefixes='')
+def load_from_local(filename, map_location):
+ """load checkpoint by local file path.
+
+ Args:
+ filename (str): local checkpoint file path
+ map_location (str, optional): Same as :func:`torch.load`.
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+
+ if not osp.isfile(filename):
+ raise IOError(f'{filename} is not a checkpoint file')
+ checkpoint = torch.load(filename, map_location=map_location)
+ return checkpoint
+
+
+@CheckpointLoader.register_scheme(prefixes=('http://', 'https://'))
+def load_from_http(filename, map_location=None, model_dir=None):
+ """load checkpoint through HTTP or HTTPS scheme path. In distributed
+ setting, this function only download checkpoint at local rank 0.
+
+ Args:
+ filename (str): checkpoint file path with modelzoo or
+ torchvision prefix
+ map_location (str, optional): Same as :func:`torch.load`.
+ model_dir (string, optional): directory in which to save the object,
+ Default: None
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+ rank, world_size = get_dist_info()
+ rank = int(os.environ.get('LOCAL_RANK', rank))
+ if rank == 0:
+ checkpoint = model_zoo.load_url(
+ filename, model_dir=model_dir, map_location=map_location)
+ if world_size > 1:
+ torch.distributed.barrier()
+ if rank > 0:
+ checkpoint = model_zoo.load_url(
+ filename, model_dir=model_dir, map_location=map_location)
+ return checkpoint
+
+
+@CheckpointLoader.register_scheme(prefixes='pavi://')
+def load_from_pavi(filename, map_location=None):
+ """load checkpoint through the file path prefixed with pavi. In distributed
+ setting, this function download ckpt at all ranks to different temporary
+ directories.
+
+ Args:
+ filename (str): checkpoint file path with pavi prefix
+ map_location (str, optional): Same as :func:`torch.load`.
+ Default: None
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+ assert filename.startswith('pavi://'), \
+ f'Expected filename startswith `pavi://`, but get {filename}'
+ model_path = filename[7:]
+
+ try:
+ from pavi import modelcloud
+ except ImportError:
+ raise ImportError(
+ 'Please install pavi to load checkpoint from modelcloud.')
+
+ model = modelcloud.get(model_path)
+ with TemporaryDirectory() as tmp_dir:
+ downloaded_file = osp.join(tmp_dir, model.name)
+ model.download(downloaded_file)
+ checkpoint = torch.load(downloaded_file, map_location=map_location)
+ return checkpoint
+
+
+@CheckpointLoader.register_scheme(prefixes='s3://')
+def load_from_ceph(filename, map_location=None, backend='petrel'):
+ """load checkpoint through the file path prefixed with s3. In distributed
+ setting, this function download ckpt at all ranks to different temporary
+ directories.
+
+ Args:
+ filename (str): checkpoint file path with s3 prefix
+ map_location (str, optional): Same as :func:`torch.load`.
+ backend (str, optional): The storage backend type. Options are 'ceph',
+ 'petrel'. Default: 'petrel'.
+
+ .. warning::
+ :class:`mmcv.fileio.file_client.CephBackend` will be deprecated,
+ please use :class:`mmcv.fileio.file_client.PetrelBackend` instead.
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+ allowed_backends = ['ceph', 'petrel']
+ if backend not in allowed_backends:
+ raise ValueError(f'Load from Backend {backend} is not supported.')
+
+ if backend == 'ceph':
+ warnings.warn(
+ 'CephBackend will be deprecated, please use PetrelBackend instead')
+
+ # CephClient and PetrelBackend have the same prefix 's3://' and the latter
+ # will be chosen as default. If PetrelBackend can not be instantiated
+ # successfully, the CephClient will be chosen.
+ try:
+ file_client = FileClient(backend=backend)
+ except ImportError:
+ allowed_backends.remove(backend)
+ file_client = FileClient(backend=allowed_backends[0])
+
+ with io.BytesIO(file_client.get(filename)) as buffer:
+ checkpoint = torch.load(buffer, map_location=map_location)
+ return checkpoint
+
+
+@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
+def load_from_torchvision(filename, map_location=None):
+ """load checkpoint through the file path prefixed with modelzoo or
+ torchvision.
+
+ Args:
+ filename (str): checkpoint file path with modelzoo or
+ torchvision prefix
+ map_location (str, optional): Same as :func:`torch.load`.
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+ model_urls = get_torchvision_models()
+ if filename.startswith('modelzoo://'):
+ warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
+ 'use "torchvision://" instead')
+ model_name = filename[11:]
+ else:
+ model_name = filename[14:]
+ return load_from_http(model_urls[model_name], map_location=map_location)
+
+
+@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
+def load_from_openmmlab(filename, map_location=None):
+ """load checkpoint through the file path prefixed with open-mmlab or
+ openmmlab.
+
+ Args:
+ filename (str): checkpoint file path with open-mmlab or
+ openmmlab prefix
+ map_location (str, optional): Same as :func:`torch.load`.
+ Default: None
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+
+ model_urls = get_external_models()
+ prefix_str = 'open-mmlab://'
+ if filename.startswith(prefix_str):
+ model_name = filename[13:]
+ else:
+ model_name = filename[12:]
+ prefix_str = 'openmmlab://'
+
+ deprecated_urls = get_deprecated_model_names()
+ if model_name in deprecated_urls:
+ warnings.warn(f'{prefix_str}{model_name} is deprecated in favor '
+ f'of {prefix_str}{deprecated_urls[model_name]}')
+ model_name = deprecated_urls[model_name]
+ model_url = model_urls[model_name]
+ # check if is url
+ if model_url.startswith(('http://', 'https://')):
+ checkpoint = load_from_http(model_url, map_location=map_location)
+ else:
+ filename = osp.join(_get_mmcv_home(), model_url)
+ if not osp.isfile(filename):
+ raise IOError(f'{filename} is not a checkpoint file')
+ checkpoint = torch.load(filename, map_location=map_location)
+ return checkpoint
+
+
+@CheckpointLoader.register_scheme(prefixes='mmcls://')
+def load_from_mmcls(filename, map_location=None):
+ """load checkpoint through the file path prefixed with mmcls.
+
+ Args:
+ filename (str): checkpoint file path with mmcls prefix
+ map_location (str, optional): Same as :func:`torch.load`.
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+
+ model_urls = get_mmcls_models()
+ model_name = filename[8:]
+ checkpoint = load_from_http(
+ model_urls[model_name], map_location=map_location)
+ checkpoint = _process_mmcls_checkpoint(checkpoint)
+ return checkpoint
+
+
+def _load_checkpoint(filename, map_location=None, logger=None):
+ """Load checkpoint from somewhere (modelzoo, file, url).
+
+ Args:
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+ details.
+ map_location (str, optional): Same as :func:`torch.load`.
+ Default: None.
+ logger (:mod:`logging.Logger`, optional): The logger for error message.
+ Default: None
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint. It can be either an
+ OrderedDict storing model weights or a dict containing other
+ information, which depends on the checkpoint.
+ """
+ return CheckpointLoader.load_checkpoint(filename, map_location, logger)
+
+
+def _load_checkpoint_with_prefix(prefix, filename, map_location=None):
+ """Load partial pretrained model with specific prefix.
+
+ Args:
+ prefix (str): The prefix of sub-module.
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+ details.
+ map_location (str | None): Same as :func:`torch.load`. Default: None.
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+
+ checkpoint = _load_checkpoint(filename, map_location=map_location)
+
+ if 'state_dict' in checkpoint:
+ state_dict = checkpoint['state_dict']
+ else:
+ state_dict = checkpoint
+ if not prefix.endswith('.'):
+ prefix += '.'
+ prefix_len = len(prefix)
+
+ state_dict = {
+ k[prefix_len:]: v
+ for k, v in state_dict.items() if k.startswith(prefix)
+ }
+
+ assert state_dict, f'{prefix} is not in the pretrained model'
+ return state_dict
+
+
+def load_checkpoint(model,
+ filename,
+ map_location=None,
+ strict=False,
+ logger=None,
+ revise_keys=[(r'^module\.', '')]):
+ """Load checkpoint from a file or URI.
+
+ Args:
+ model (Module): Module to load checkpoint.
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+ details.
+ map_location (str): Same as :func:`torch.load`.
+ strict (bool): Whether to allow different params for the model and
+ checkpoint.
+ logger (:mod:`logging.Logger` or None): The logger for error message.
+ revise_keys (list): A list of customized keywords to modify the
+ state_dict in checkpoint. Each item is a (pattern, replacement)
+ pair of the regular expression operations. Default: strip
+ the prefix 'module.' by [(r'^module\\.', '')].
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+ checkpoint = _load_checkpoint(filename, map_location, logger)
+ # OrderedDict is a subclass of dict
+ if not isinstance(checkpoint, dict):
+ raise RuntimeError(
+ f'No state_dict found in checkpoint file {filename}')
+ # get state_dict from checkpoint
+ if 'state_dict' in checkpoint:
+ state_dict = checkpoint['state_dict']
+ else:
+ state_dict = checkpoint
+
+ # strip prefix of state_dict
+ metadata = getattr(state_dict, '_metadata', OrderedDict())
+ for p, r in revise_keys:
+ state_dict = OrderedDict(
+ {re.sub(p, r, k): v
+ for k, v in state_dict.items()})
+ # Keep metadata in state_dict
+ state_dict._metadata = metadata
+
+ # load state_dict
+ load_state_dict(model, state_dict, strict, logger)
+ return checkpoint
+
+
+def weights_to_cpu(state_dict):
+ """Copy a model state_dict to cpu.
+
+ Args:
+ state_dict (OrderedDict): Model weights on GPU.
+
+ Returns:
+ OrderedDict: Model weights on GPU.
+ """
+ state_dict_cpu = OrderedDict()
+ for key, val in state_dict.items():
+ state_dict_cpu[key] = val.cpu()
+ # Keep metadata in state_dict
+ state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
+ return state_dict_cpu
+
+
+def _save_to_state_dict(module, destination, prefix, keep_vars):
+ """Saves module state to `destination` dictionary.
+
+ This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
+
+ Args:
+ module (nn.Module): The module to generate state_dict.
+ destination (dict): A dict where state will be stored.
+ prefix (str): The prefix for parameters and buffers used in this
+ module.
+ """
+ for name, param in module._parameters.items():
+ if param is not None:
+ destination[prefix + name] = param if keep_vars else param.detach()
+ for name, buf in module._buffers.items():
+ # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
+ if buf is not None:
+ destination[prefix + name] = buf if keep_vars else buf.detach()
+
+
+def get_state_dict(module, destination=None, prefix='', keep_vars=False):
+ """Returns a dictionary containing a whole state of the module.
+
+ Both parameters and persistent buffers (e.g. running averages) are
+ included. Keys are corresponding parameter and buffer names.
+
+ This method is modified from :meth:`torch.nn.Module.state_dict` to
+ recursively check parallel module in case that the model has a complicated
+ structure, e.g., nn.Module(nn.Module(DDP)).
+
+ Args:
+ module (nn.Module): The module to generate state_dict.
+ destination (OrderedDict): Returned dict for the state of the
+ module.
+ prefix (str): Prefix of the key.
+ keep_vars (bool): Whether to keep the variable property of the
+ parameters. Default: False.
+
+ Returns:
+ dict: A dictionary containing a whole state of the module.
+ """
+ # recursively check parallel module in case that the model has a
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
+ if is_module_wrapper(module):
+ module = module.module
+
+ # below is the same as torch.nn.Module.state_dict()
+ if destination is None:
+ destination = OrderedDict()
+ destination._metadata = OrderedDict()
+ destination._metadata[prefix[:-1]] = local_metadata = dict(
+ version=module._version)
+ _save_to_state_dict(module, destination, prefix, keep_vars)
+ for name, child in module._modules.items():
+ if child is not None:
+ get_state_dict(
+ child, destination, prefix + name + '.', keep_vars=keep_vars)
+ for hook in module._state_dict_hooks.values():
+ hook_result = hook(module, destination, prefix, local_metadata)
+ if hook_result is not None:
+ destination = hook_result
+ return destination
+
+
+def save_checkpoint(model,
+ filename,
+ optimizer=None,
+ meta=None,
+ file_client_args=None):
+ """Save checkpoint to file.
+
+ The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
+ ``optimizer``. By default ``meta`` will contain version and time info.
+
+ Args:
+ model (Module): Module whose params are to be saved.
+ filename (str): Checkpoint filename.
+ optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
+ meta (dict, optional): Metadata to be saved in checkpoint.
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
+ Default: None.
+ `New in version 1.3.16.`
+ """
+ if meta is None:
+ meta = {}
+ elif not isinstance(meta, dict):
+ raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
+ meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
+
+ if is_module_wrapper(model):
+ model = model.module
+
+ if hasattr(model, 'CLASSES') and model.CLASSES is not None:
+ # save class name to the meta
+ meta.update(CLASSES=model.CLASSES)
+
+ checkpoint = {
+ 'meta': meta,
+ 'state_dict': weights_to_cpu(get_state_dict(model))
+ }
+ # save optimizer state dict in the checkpoint
+ if isinstance(optimizer, Optimizer):
+ checkpoint['optimizer'] = optimizer.state_dict()
+ elif isinstance(optimizer, dict):
+ checkpoint['optimizer'] = {}
+ for name, optim in optimizer.items():
+ checkpoint['optimizer'][name] = optim.state_dict()
+
+ if filename.startswith('pavi://'):
+ if file_client_args is not None:
+ raise ValueError(
+ 'file_client_args should be "None" if filename starts with'
+ f'"pavi://", but got {file_client_args}')
+ try:
+ from pavi import modelcloud
+ from pavi import exception
+ except ImportError:
+ raise ImportError(
+ 'Please install pavi to load checkpoint from modelcloud.')
+ model_path = filename[7:]
+ root = modelcloud.Folder()
+ model_dir, model_name = osp.split(model_path)
+ try:
+ model = modelcloud.get(model_dir)
+ except exception.NodeNotFoundError:
+ model = root.create_training_model(model_dir)
+ with TemporaryDirectory() as tmp_dir:
+ checkpoint_file = osp.join(tmp_dir, model_name)
+ with open(checkpoint_file, 'wb') as f:
+ torch.save(checkpoint, f)
+ f.flush()
+ model.create_file(checkpoint_file, name=model_name)
+ else:
+ file_client = FileClient.infer_client(file_client_args, filename)
+ with io.BytesIO() as f:
+ torch.save(checkpoint, f)
+ file_client.put(f.getvalue(), filename)
diff --git a/src/custom_mmpkg/custom_mmcv/runner/default_constructor.py b/src/custom_mmpkg/custom_mmcv/runner/default_constructor.py
new file mode 100644
index 0000000000000000000000000000000000000000..fed2e0c83e19133ce3873ea092c1a872ca254bbf
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/default_constructor.py
@@ -0,0 +1,44 @@
+from .builder import RUNNER_BUILDERS, RUNNERS
+
+
+@RUNNER_BUILDERS.register_module()
+class DefaultRunnerConstructor:
+ """Default constructor for runners.
+
+ Custom existing `Runner` like `EpocBasedRunner` though `RunnerConstructor`.
+ For example, We can inject some new properties and functions for `Runner`.
+
+ Example:
+ >>> from custom_mmpkg.custom_mmcv.runner import RUNNER_BUILDERS, build_runner
+ >>> # Define a new RunnerReconstructor
+ >>> @RUNNER_BUILDERS.register_module()
+ >>> class MyRunnerConstructor:
+ ... def __init__(self, runner_cfg, default_args=None):
+ ... if not isinstance(runner_cfg, dict):
+ ... raise TypeError('runner_cfg should be a dict',
+ ... f'but got {type(runner_cfg)}')
+ ... self.runner_cfg = runner_cfg
+ ... self.default_args = default_args
+ ...
+ ... def __call__(self):
+ ... runner = RUNNERS.build(self.runner_cfg,
+ ... default_args=self.default_args)
+ ... # Add new properties for existing runner
+ ... runner.my_name = 'my_runner'
+ ... runner.my_function = lambda self: print(self.my_name)
+ ... ...
+ >>> # build your runner
+ >>> runner_cfg = dict(type='EpochBasedRunner', max_epochs=40,
+ ... constructor='MyRunnerConstructor')
+ >>> runner = build_runner(runner_cfg)
+ """
+
+ def __init__(self, runner_cfg, default_args=None):
+ if not isinstance(runner_cfg, dict):
+ raise TypeError('runner_cfg should be a dict',
+ f'but got {type(runner_cfg)}')
+ self.runner_cfg = runner_cfg
+ self.default_args = default_args
+
+ def __call__(self):
+ return RUNNERS.build(self.runner_cfg, default_args=self.default_args)
diff --git a/src/custom_mmpkg/custom_mmcv/runner/dist_utils.py b/src/custom_mmpkg/custom_mmcv/runner/dist_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3a1ef3fda5ceeb31bf15a73779da1b1903ab0fe
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/dist_utils.py
@@ -0,0 +1,164 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+import os
+import subprocess
+from collections import OrderedDict
+
+import torch
+import torch.multiprocessing as mp
+from torch import distributed as dist
+from torch._utils import (_flatten_dense_tensors, _take_tensors,
+ _unflatten_dense_tensors)
+
+
+def init_dist(launcher, backend='nccl', **kwargs):
+ if mp.get_start_method(allow_none=True) is None:
+ mp.set_start_method('spawn')
+ if launcher == 'pytorch':
+ _init_dist_pytorch(backend, **kwargs)
+ elif launcher == 'mpi':
+ _init_dist_mpi(backend, **kwargs)
+ elif launcher == 'slurm':
+ _init_dist_slurm(backend, **kwargs)
+ else:
+ raise ValueError(f'Invalid launcher type: {launcher}')
+
+
+def _init_dist_pytorch(backend, **kwargs):
+ # TODO: use local_rank instead of rank % num_gpus
+ rank = int(os.environ['RANK'])
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(rank % num_gpus)
+ dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_mpi(backend, **kwargs):
+ # TODO: use local_rank instead of rank % num_gpus
+ rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(rank % num_gpus)
+ dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_slurm(backend, port=None):
+ """Initialize slurm distributed training environment.
+
+ If argument ``port`` is not specified, then the master port will be system
+ environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
+ environment variable, then a default port ``29500`` will be used.
+
+ Args:
+ backend (str): Backend of torch.distributed.
+ port (int, optional): Master port. Defaults to None.
+ """
+ proc_id = int(os.environ['SLURM_PROCID'])
+ ntasks = int(os.environ['SLURM_NTASKS'])
+ node_list = os.environ['SLURM_NODELIST']
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(proc_id % num_gpus)
+ addr = subprocess.getoutput(
+ f'scontrol show hostname {node_list} | head -n1')
+ # specify master port
+ if port is not None:
+ os.environ['MASTER_PORT'] = str(port)
+ elif 'MASTER_PORT' in os.environ:
+ pass # use MASTER_PORT in the environment variable
+ else:
+ # 29500 is torch.distributed default port
+ os.environ['MASTER_PORT'] = '29500'
+ # use MASTER_ADDR in the environment variable if it already exists
+ if 'MASTER_ADDR' not in os.environ:
+ os.environ['MASTER_ADDR'] = addr
+ os.environ['WORLD_SIZE'] = str(ntasks)
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
+ os.environ['RANK'] = str(proc_id)
+ dist.init_process_group(backend=backend)
+
+
+def get_dist_info():
+ if dist.is_available() and dist.is_initialized():
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ else:
+ rank = 0
+ world_size = 1
+ return rank, world_size
+
+
+def master_only(func):
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ rank, _ = get_dist_info()
+ if rank == 0:
+ return func(*args, **kwargs)
+
+ return wrapper
+
+
+def allreduce_params(params, coalesce=True, bucket_size_mb=-1):
+ """Allreduce parameters.
+
+ Args:
+ params (list[torch.Parameters]): List of parameters or buffers of a
+ model.
+ coalesce (bool, optional): Whether allreduce parameters as a whole.
+ Defaults to True.
+ bucket_size_mb (int, optional): Size of bucket, the unit is MB.
+ Defaults to -1.
+ """
+ _, world_size = get_dist_info()
+ if world_size == 1:
+ return
+ params = [param.data for param in params]
+ if coalesce:
+ _allreduce_coalesced(params, world_size, bucket_size_mb)
+ else:
+ for tensor in params:
+ dist.all_reduce(tensor.div_(world_size))
+
+
+def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
+ """Allreduce gradients.
+
+ Args:
+ params (list[torch.Parameters]): List of parameters of a model
+ coalesce (bool, optional): Whether allreduce parameters as a whole.
+ Defaults to True.
+ bucket_size_mb (int, optional): Size of bucket, the unit is MB.
+ Defaults to -1.
+ """
+ grads = [
+ param.grad.data for param in params
+ if param.requires_grad and param.grad is not None
+ ]
+ _, world_size = get_dist_info()
+ if world_size == 1:
+ return
+ if coalesce:
+ _allreduce_coalesced(grads, world_size, bucket_size_mb)
+ else:
+ for tensor in grads:
+ dist.all_reduce(tensor.div_(world_size))
+
+
+def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
+ if bucket_size_mb > 0:
+ bucket_size_bytes = bucket_size_mb * 1024 * 1024
+ buckets = _take_tensors(tensors, bucket_size_bytes)
+ else:
+ buckets = OrderedDict()
+ for tensor in tensors:
+ tp = tensor.type()
+ if tp not in buckets:
+ buckets[tp] = []
+ buckets[tp].append(tensor)
+ buckets = buckets.values()
+
+ for bucket in buckets:
+ flat_tensors = _flatten_dense_tensors(bucket)
+ dist.all_reduce(flat_tensors)
+ flat_tensors.div_(world_size)
+ for tensor, synced in zip(
+ bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
+ tensor.copy_(synced)
diff --git a/src/custom_mmpkg/custom_mmcv/runner/epoch_based_runner.py b/src/custom_mmpkg/custom_mmcv/runner/epoch_based_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..46b96618daec0941513cc0188edfa45e4c42dfe2
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/epoch_based_runner.py
@@ -0,0 +1,187 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import platform
+import shutil
+import time
+import warnings
+
+import torch
+
+import custom_mmpkg.custom_mmcv as mmcv
+from .base_runner import BaseRunner
+from .builder import RUNNERS
+from .checkpoint import save_checkpoint
+from .utils import get_host_info
+
+
+@RUNNERS.register_module()
+class EpochBasedRunner(BaseRunner):
+ """Epoch-based Runner.
+
+ This runner train models epoch by epoch.
+ """
+
+ def run_iter(self, data_batch, train_mode, **kwargs):
+ if self.batch_processor is not None:
+ outputs = self.batch_processor(
+ self.model, data_batch, train_mode=train_mode, **kwargs)
+ elif train_mode:
+ outputs = self.model.train_step(data_batch, self.optimizer,
+ **kwargs)
+ else:
+ outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
+ if not isinstance(outputs, dict):
+ raise TypeError('"batch_processor()" or "model.train_step()"'
+ 'and "model.val_step()" must return a dict')
+ if 'log_vars' in outputs:
+ self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
+ self.outputs = outputs
+
+ def train(self, data_loader, **kwargs):
+ self.model.train()
+ self.mode = 'train'
+ self.data_loader = data_loader
+ self._max_iters = self._max_epochs * len(self.data_loader)
+ self.call_hook('before_train_epoch')
+ time.sleep(2) # Prevent possible deadlock during epoch transition
+ for i, data_batch in enumerate(self.data_loader):
+ self._inner_iter = i
+ self.call_hook('before_train_iter')
+ self.run_iter(data_batch, train_mode=True, **kwargs)
+ self.call_hook('after_train_iter')
+ self._iter += 1
+
+ self.call_hook('after_train_epoch')
+ self._epoch += 1
+
+ @torch.no_grad()
+ def val(self, data_loader, **kwargs):
+ self.model.eval()
+ self.mode = 'val'
+ self.data_loader = data_loader
+ self.call_hook('before_val_epoch')
+ time.sleep(2) # Prevent possible deadlock during epoch transition
+ for i, data_batch in enumerate(self.data_loader):
+ self._inner_iter = i
+ self.call_hook('before_val_iter')
+ self.run_iter(data_batch, train_mode=False)
+ self.call_hook('after_val_iter')
+
+ self.call_hook('after_val_epoch')
+
+ def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
+ """Start running.
+
+ Args:
+ data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
+ and validation.
+ workflow (list[tuple]): A list of (phase, epochs) to specify the
+ running order and epochs. E.g, [('train', 2), ('val', 1)] means
+ running 2 epochs for training and 1 epoch for validation,
+ iteratively.
+ """
+ assert isinstance(data_loaders, list)
+ assert mmcv.is_list_of(workflow, tuple)
+ assert len(data_loaders) == len(workflow)
+ if max_epochs is not None:
+ warnings.warn(
+ 'setting max_epochs in run is deprecated, '
+ 'please set max_epochs in runner_config', DeprecationWarning)
+ self._max_epochs = max_epochs
+
+ assert self._max_epochs is not None, (
+ 'max_epochs must be specified during instantiation')
+
+ for i, flow in enumerate(workflow):
+ mode, epochs = flow
+ if mode == 'train':
+ self._max_iters = self._max_epochs * len(data_loaders[i])
+ break
+
+ work_dir = self.work_dir if self.work_dir is not None else 'NONE'
+ self.logger.info('Start running, host: %s, work_dir: %s',
+ get_host_info(), work_dir)
+ self.logger.info('Hooks will be executed in the following order:\n%s',
+ self.get_hook_info())
+ self.logger.info('workflow: %s, max: %d epochs', workflow,
+ self._max_epochs)
+ self.call_hook('before_run')
+
+ while self.epoch < self._max_epochs:
+ for i, flow in enumerate(workflow):
+ mode, epochs = flow
+ if isinstance(mode, str): # self.train()
+ if not hasattr(self, mode):
+ raise ValueError(
+ f'runner has no method named "{mode}" to run an '
+ 'epoch')
+ epoch_runner = getattr(self, mode)
+ else:
+ raise TypeError(
+ 'mode in workflow must be a str, but got {}'.format(
+ type(mode)))
+
+ for _ in range(epochs):
+ if mode == 'train' and self.epoch >= self._max_epochs:
+ break
+ epoch_runner(data_loaders[i], **kwargs)
+
+ time.sleep(1) # wait for some hooks like loggers to finish
+ self.call_hook('after_run')
+
+ def save_checkpoint(self,
+ out_dir,
+ filename_tmpl='epoch_{}.pth',
+ save_optimizer=True,
+ meta=None,
+ create_symlink=True):
+ """Save the checkpoint.
+
+ Args:
+ out_dir (str): The directory that checkpoints are saved.
+ filename_tmpl (str, optional): The checkpoint filename template,
+ which contains a placeholder for the epoch number.
+ Defaults to 'epoch_{}.pth'.
+ save_optimizer (bool, optional): Whether to save the optimizer to
+ the checkpoint. Defaults to True.
+ meta (dict, optional): The meta information to be saved in the
+ checkpoint. Defaults to None.
+ create_symlink (bool, optional): Whether to create a symlink
+ "latest.pth" to point to the latest checkpoint.
+ Defaults to True.
+ """
+ if meta is None:
+ meta = {}
+ elif not isinstance(meta, dict):
+ raise TypeError(
+ f'meta should be a dict or None, but got {type(meta)}')
+ if self.meta is not None:
+ meta.update(self.meta)
+ # Note: meta.update(self.meta) should be done before
+ # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise
+ # there will be problems with resumed checkpoints.
+ # More details in https://github.com/open-mmlab/mmcv/pull/1108
+ meta.update(epoch=self.epoch + 1, iter=self.iter)
+
+ filename = filename_tmpl.format(self.epoch + 1)
+ filepath = osp.join(out_dir, filename)
+ optimizer = self.optimizer if save_optimizer else None
+ save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
+ # in some environments, `os.symlink` is not supported, you may need to
+ # set `create_symlink` to False
+ if create_symlink:
+ dst_file = osp.join(out_dir, 'latest.pth')
+ if platform.system() != 'Windows':
+ mmcv.symlink(filename, dst_file)
+ else:
+ shutil.copy(filepath, dst_file)
+
+
+@RUNNERS.register_module()
+class Runner(EpochBasedRunner):
+ """Deprecated name of EpochBasedRunner."""
+
+ def __init__(self, *args, **kwargs):
+ warnings.warn(
+ 'Runner was deprecated, please use EpochBasedRunner instead')
+ super().__init__(*args, **kwargs)
diff --git a/src/custom_mmpkg/custom_mmcv/runner/fp16_utils.py b/src/custom_mmpkg/custom_mmcv/runner/fp16_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..752e1b23fb971a56f72ea6dfee36166670221e93
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/fp16_utils.py
@@ -0,0 +1,410 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+import warnings
+from collections import abc
+from inspect import getfullargspec
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from custom_mmpkg.custom_mmcv.utils import TORCH_VERSION, digit_version
+from .dist_utils import allreduce_grads as _allreduce_grads
+
+try:
+ # If PyTorch version >= 1.6.0, torch.cuda.amp.autocast would be imported
+ # and used; otherwise, auto fp16 will adopt mmcv's implementation.
+ # Note that when PyTorch >= 1.6.0, we still cast tensor types to fp16
+ # manually, so the behavior may not be consistent with real amp.
+ from torch.cuda.amp import autocast
+except ImportError:
+ pass
+
+
+def cast_tensor_type(inputs, src_type, dst_type):
+ """Recursively convert Tensor in inputs from src_type to dst_type.
+
+ Args:
+ inputs: Inputs that to be casted.
+ src_type (torch.dtype): Source type..
+ dst_type (torch.dtype): Destination type.
+
+ Returns:
+ The same type with inputs, but all contained Tensors have been cast.
+ """
+ if isinstance(inputs, nn.Module):
+ return inputs
+ elif isinstance(inputs, torch.Tensor):
+ return inputs.to(dst_type)
+ elif isinstance(inputs, str):
+ return inputs
+ elif isinstance(inputs, np.ndarray):
+ return inputs
+ elif isinstance(inputs, abc.Mapping):
+ return type(inputs)({
+ k: cast_tensor_type(v, src_type, dst_type)
+ for k, v in inputs.items()
+ })
+ elif isinstance(inputs, abc.Iterable):
+ return type(inputs)(
+ cast_tensor_type(item, src_type, dst_type) for item in inputs)
+ else:
+ return inputs
+
+
+def auto_fp16(apply_to=None, out_fp32=False):
+ """Decorator to enable fp16 training automatically.
+
+ This decorator is useful when you write custom modules and want to support
+ mixed precision training. If inputs arguments are fp32 tensors, they will
+ be converted to fp16 automatically. Arguments other than fp32 tensors are
+ ignored. If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
+ backend, otherwise, original mmcv implementation will be adopted.
+
+ Args:
+ apply_to (Iterable, optional): The argument names to be converted.
+ `None` indicates all arguments.
+ out_fp32 (bool): Whether to convert the output back to fp32.
+
+ Example:
+
+ >>> import torch.nn as nn
+ >>> class MyModule1(nn.Module):
+ >>>
+ >>> # Convert x and y to fp16
+ >>> @auto_fp16()
+ >>> def forward(self, x, y):
+ >>> pass
+
+ >>> import torch.nn as nn
+ >>> class MyModule2(nn.Module):
+ >>>
+ >>> # convert pred to fp16
+ >>> @auto_fp16(apply_to=('pred', ))
+ >>> def do_something(self, pred, others):
+ >>> pass
+ """
+
+ def auto_fp16_wrapper(old_func):
+
+ @functools.wraps(old_func)
+ def new_func(*args, **kwargs):
+ # check if the module has set the attribute `fp16_enabled`, if not,
+ # just fallback to the original method.
+ if not isinstance(args[0], torch.nn.Module):
+ raise TypeError('@auto_fp16 can only be used to decorate the '
+ 'method of nn.Module')
+ if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
+ return old_func(*args, **kwargs)
+
+ # get the arg spec of the decorated method
+ args_info = getfullargspec(old_func)
+ # get the argument names to be casted
+ args_to_cast = args_info.args if apply_to is None else apply_to
+ # convert the args that need to be processed
+ new_args = []
+ # NOTE: default args are not taken into consideration
+ if args:
+ arg_names = args_info.args[:len(args)]
+ for i, arg_name in enumerate(arg_names):
+ if arg_name in args_to_cast:
+ new_args.append(
+ cast_tensor_type(args[i], torch.float, torch.half))
+ else:
+ new_args.append(args[i])
+ # convert the kwargs that need to be processed
+ new_kwargs = {}
+ if kwargs:
+ for arg_name, arg_value in kwargs.items():
+ if arg_name in args_to_cast:
+ new_kwargs[arg_name] = cast_tensor_type(
+ arg_value, torch.float, torch.half)
+ else:
+ new_kwargs[arg_name] = arg_value
+ # apply converted arguments to the decorated method
+ if (TORCH_VERSION != 'parrots' and
+ digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
+ with autocast(enabled=True):
+ output = old_func(*new_args, **new_kwargs)
+ else:
+ output = old_func(*new_args, **new_kwargs)
+ # cast the results back to fp32 if necessary
+ if out_fp32:
+ output = cast_tensor_type(output, torch.half, torch.float)
+ return output
+
+ return new_func
+
+ return auto_fp16_wrapper
+
+
+def force_fp32(apply_to=None, out_fp16=False):
+ """Decorator to convert input arguments to fp32 in force.
+
+ This decorator is useful when you write custom modules and want to support
+ mixed precision training. If there are some inputs that must be processed
+ in fp32 mode, then this decorator can handle it. If inputs arguments are
+ fp16 tensors, they will be converted to fp32 automatically. Arguments other
+ than fp16 tensors are ignored. If you are using PyTorch >= 1.6,
+ torch.cuda.amp is used as the backend, otherwise, original mmcv
+ implementation will be adopted.
+
+ Args:
+ apply_to (Iterable, optional): The argument names to be converted.
+ `None` indicates all arguments.
+ out_fp16 (bool): Whether to convert the output back to fp16.
+
+ Example:
+
+ >>> import torch.nn as nn
+ >>> class MyModule1(nn.Module):
+ >>>
+ >>> # Convert x and y to fp32
+ >>> @force_fp32()
+ >>> def loss(self, x, y):
+ >>> pass
+
+ >>> import torch.nn as nn
+ >>> class MyModule2(nn.Module):
+ >>>
+ >>> # convert pred to fp32
+ >>> @force_fp32(apply_to=('pred', ))
+ >>> def post_process(self, pred, others):
+ >>> pass
+ """
+
+ def force_fp32_wrapper(old_func):
+
+ @functools.wraps(old_func)
+ def new_func(*args, **kwargs):
+ # check if the module has set the attribute `fp16_enabled`, if not,
+ # just fallback to the original method.
+ if not isinstance(args[0], torch.nn.Module):
+ raise TypeError('@force_fp32 can only be used to decorate the '
+ 'method of nn.Module')
+ if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
+ return old_func(*args, **kwargs)
+ # get the arg spec of the decorated method
+ args_info = getfullargspec(old_func)
+ # get the argument names to be casted
+ args_to_cast = args_info.args if apply_to is None else apply_to
+ # convert the args that need to be processed
+ new_args = []
+ if args:
+ arg_names = args_info.args[:len(args)]
+ for i, arg_name in enumerate(arg_names):
+ if arg_name in args_to_cast:
+ new_args.append(
+ cast_tensor_type(args[i], torch.half, torch.float))
+ else:
+ new_args.append(args[i])
+ # convert the kwargs that need to be processed
+ new_kwargs = dict()
+ if kwargs:
+ for arg_name, arg_value in kwargs.items():
+ if arg_name in args_to_cast:
+ new_kwargs[arg_name] = cast_tensor_type(
+ arg_value, torch.half, torch.float)
+ else:
+ new_kwargs[arg_name] = arg_value
+ # apply converted arguments to the decorated method
+ if (TORCH_VERSION != 'parrots' and
+ digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
+ with autocast(enabled=False):
+ output = old_func(*new_args, **new_kwargs)
+ else:
+ output = old_func(*new_args, **new_kwargs)
+ # cast the results back to fp32 if necessary
+ if out_fp16:
+ output = cast_tensor_type(output, torch.float, torch.half)
+ return output
+
+ return new_func
+
+ return force_fp32_wrapper
+
+
+def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
+ warnings.warning(
+ '"mmcv.runner.fp16_utils.allreduce_grads" is deprecated, and will be '
+ 'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads')
+ _allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb)
+
+
+def wrap_fp16_model(model):
+ """Wrap the FP32 model to FP16.
+
+ If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
+ backend, otherwise, original mmcv implementation will be adopted.
+
+ For PyTorch >= 1.6, this function will
+ 1. Set fp16 flag inside the model to True.
+
+ Otherwise:
+ 1. Convert FP32 model to FP16.
+ 2. Remain some necessary layers to be FP32, e.g., normalization layers.
+ 3. Set `fp16_enabled` flag inside the model to True.
+
+ Args:
+ model (nn.Module): Model in FP32.
+ """
+ if (TORCH_VERSION == 'parrots'
+ or digit_version(TORCH_VERSION) < digit_version('1.6.0')):
+ # convert model to fp16
+ model.half()
+ # patch the normalization layers to make it work in fp32 mode
+ patch_norm_fp32(model)
+ # set `fp16_enabled` flag
+ for m in model.modules():
+ if hasattr(m, 'fp16_enabled'):
+ m.fp16_enabled = True
+
+
+def patch_norm_fp32(module):
+ """Recursively convert normalization layers from FP16 to FP32.
+
+ Args:
+ module (nn.Module): The modules to be converted in FP16.
+
+ Returns:
+ nn.Module: The converted module, the normalization layers have been
+ converted to FP32.
+ """
+ if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
+ module.float()
+ if isinstance(module, nn.GroupNorm) or torch.__version__ < '1.3':
+ module.forward = patch_forward_method(module.forward, torch.half,
+ torch.float)
+ for child in module.children():
+ patch_norm_fp32(child)
+ return module
+
+
+def patch_forward_method(func, src_type, dst_type, convert_output=True):
+ """Patch the forward method of a module.
+
+ Args:
+ func (callable): The original forward method.
+ src_type (torch.dtype): Type of input arguments to be converted from.
+ dst_type (torch.dtype): Type of input arguments to be converted to.
+ convert_output (bool): Whether to convert the output back to src_type.
+
+ Returns:
+ callable: The patched forward method.
+ """
+
+ def new_forward(*args, **kwargs):
+ output = func(*cast_tensor_type(args, src_type, dst_type),
+ **cast_tensor_type(kwargs, src_type, dst_type))
+ if convert_output:
+ output = cast_tensor_type(output, dst_type, src_type)
+ return output
+
+ return new_forward
+
+
+class LossScaler:
+ """Class that manages loss scaling in mixed precision training which
+ supports both dynamic or static mode.
+
+ The implementation refers to
+ https://github.com/NVIDIA/apex/blob/master/apex/fp16_utils/loss_scaler.py.
+ Indirectly, by supplying ``mode='dynamic'`` for dynamic loss scaling.
+ It's important to understand how :class:`LossScaler` operates.
+ Loss scaling is designed to combat the problem of underflowing
+ gradients encountered at long times when training fp16 networks.
+ Dynamic loss scaling begins by attempting a very high loss
+ scale. Ironically, this may result in OVERflowing gradients.
+ If overflowing gradients are encountered, :class:`FP16_Optimizer` then
+ skips the update step for this particular iteration/minibatch,
+ and :class:`LossScaler` adjusts the loss scale to a lower value.
+ If a certain number of iterations occur without overflowing gradients
+ detected,:class:`LossScaler` increases the loss scale once more.
+ In this way :class:`LossScaler` attempts to "ride the edge" of always
+ using the highest loss scale possible without incurring overflow.
+
+ Args:
+ init_scale (float): Initial loss scale value, default: 2**32.
+ scale_factor (float): Factor used when adjusting the loss scale.
+ Default: 2.
+ mode (str): Loss scaling mode. 'dynamic' or 'static'
+ scale_window (int): Number of consecutive iterations without an
+ overflow to wait before increasing the loss scale. Default: 1000.
+ """
+
+ def __init__(self,
+ init_scale=2**32,
+ mode='dynamic',
+ scale_factor=2.,
+ scale_window=1000):
+ self.cur_scale = init_scale
+ self.cur_iter = 0
+ assert mode in ('dynamic',
+ 'static'), 'mode can only be dynamic or static'
+ self.mode = mode
+ self.last_overflow_iter = -1
+ self.scale_factor = scale_factor
+ self.scale_window = scale_window
+
+ def has_overflow(self, params):
+ """Check if params contain overflow."""
+ if self.mode != 'dynamic':
+ return False
+ for p in params:
+ if p.grad is not None and LossScaler._has_inf_or_nan(p.grad.data):
+ return True
+ return False
+
+ def _has_inf_or_nan(x):
+ """Check if params contain NaN."""
+ try:
+ cpu_sum = float(x.float().sum())
+ except RuntimeError as instance:
+ if 'value cannot be converted' not in instance.args[0]:
+ raise
+ return True
+ else:
+ if cpu_sum == float('inf') or cpu_sum == -float('inf') \
+ or cpu_sum != cpu_sum:
+ return True
+ return False
+
+ def update_scale(self, overflow):
+ """update the current loss scale value when overflow happens."""
+ if self.mode != 'dynamic':
+ return
+ if overflow:
+ self.cur_scale = max(self.cur_scale / self.scale_factor, 1)
+ self.last_overflow_iter = self.cur_iter
+ else:
+ if (self.cur_iter - self.last_overflow_iter) % \
+ self.scale_window == 0:
+ self.cur_scale *= self.scale_factor
+ self.cur_iter += 1
+
+ def state_dict(self):
+ """Returns the state of the scaler as a :class:`dict`."""
+ return dict(
+ cur_scale=self.cur_scale,
+ cur_iter=self.cur_iter,
+ mode=self.mode,
+ last_overflow_iter=self.last_overflow_iter,
+ scale_factor=self.scale_factor,
+ scale_window=self.scale_window)
+
+ def load_state_dict(self, state_dict):
+ """Loads the loss_scaler state dict.
+
+ Args:
+ state_dict (dict): scaler state.
+ """
+ self.cur_scale = state_dict['cur_scale']
+ self.cur_iter = state_dict['cur_iter']
+ self.mode = state_dict['mode']
+ self.last_overflow_iter = state_dict['last_overflow_iter']
+ self.scale_factor = state_dict['scale_factor']
+ self.scale_window = state_dict['scale_window']
+
+ @property
+ def loss_scale(self):
+ return self.cur_scale
diff --git a/src/custom_mmpkg/custom_mmcv/runner/hooks/__init__.py b/src/custom_mmpkg/custom_mmcv/runner/hooks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..915af28cefab14a14c1188ed861161080fd138a3
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/hooks/__init__.py
@@ -0,0 +1,29 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .checkpoint import CheckpointHook
+from .closure import ClosureHook
+from .ema import EMAHook
+from .evaluation import DistEvalHook, EvalHook
+from .hook import HOOKS, Hook
+from .iter_timer import IterTimerHook
+from .logger import (DvcliveLoggerHook, LoggerHook, MlflowLoggerHook,
+ NeptuneLoggerHook, PaviLoggerHook, TensorboardLoggerHook,
+ TextLoggerHook, WandbLoggerHook)
+from .lr_updater import LrUpdaterHook
+from .memory import EmptyCacheHook
+from .momentum_updater import MomentumUpdaterHook
+from .optimizer import (Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
+ GradientCumulativeOptimizerHook, OptimizerHook)
+from .profiler import ProfilerHook
+from .sampler_seed import DistSamplerSeedHook
+from .sync_buffer import SyncBuffersHook
+
+__all__ = [
+ 'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
+ 'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook',
+ 'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook',
+ 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
+ 'NeptuneLoggerHook', 'WandbLoggerHook', 'DvcliveLoggerHook',
+ 'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook', 'EvalHook',
+ 'DistEvalHook', 'ProfilerHook', 'GradientCumulativeOptimizerHook',
+ 'GradientCumulativeFp16OptimizerHook'
+]
diff --git a/src/custom_mmpkg/custom_mmcv/runner/hooks/checkpoint.py b/src/custom_mmpkg/custom_mmcv/runner/hooks/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a1b688bcbd9877423ba3930a81093464aed34f6
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/hooks/checkpoint.py
@@ -0,0 +1,167 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import warnings
+
+from custom_mmpkg.custom_mmcv.fileio import FileClient
+from ..dist_utils import allreduce_params, master_only
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class CheckpointHook(Hook):
+ """Save checkpoints periodically.
+
+ Args:
+ interval (int): The saving period. If ``by_epoch=True``, interval
+ indicates epochs, otherwise it indicates iterations.
+ Default: -1, which means "never".
+ by_epoch (bool): Saving checkpoints by epoch or by iteration.
+ Default: True.
+ save_optimizer (bool): Whether to save optimizer state_dict in the
+ checkpoint. It is usually used for resuming experiments.
+ Default: True.
+ out_dir (str, optional): The root directory to save checkpoints. If not
+ specified, ``runner.work_dir`` will be used by default. If
+ specified, the ``out_dir`` will be the concatenation of ``out_dir``
+ and the last level directory of ``runner.work_dir``.
+ `Changed in version 1.3.16.`
+ max_keep_ckpts (int, optional): The maximum checkpoints to keep.
+ In some cases we want only the latest few checkpoints and would
+ like to delete old ones to save the disk space.
+ Default: -1, which means unlimited.
+ save_last (bool, optional): Whether to force the last checkpoint to be
+ saved regardless of interval. Default: True.
+ sync_buffer (bool, optional): Whether to synchronize buffers in
+ different gpus. Default: False.
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
+ Default: None.
+ `New in version 1.3.16.`
+
+ .. warning::
+ Before v1.3.16, the ``out_dir`` argument indicates the path where the
+ checkpoint is stored. However, since v1.3.16, ``out_dir`` indicates the
+ root directory and the final path to save checkpoint is the
+ concatenation of ``out_dir`` and the last level directory of
+ ``runner.work_dir``. Suppose the value of ``out_dir`` is "/path/of/A"
+ and the value of ``runner.work_dir`` is "/path/of/B", then the final
+ path will be "/path/of/A/B".
+ """
+
+ def __init__(self,
+ interval=-1,
+ by_epoch=True,
+ save_optimizer=True,
+ out_dir=None,
+ max_keep_ckpts=-1,
+ save_last=True,
+ sync_buffer=False,
+ file_client_args=None,
+ **kwargs):
+ self.interval = interval
+ self.by_epoch = by_epoch
+ self.save_optimizer = save_optimizer
+ self.out_dir = out_dir
+ self.max_keep_ckpts = max_keep_ckpts
+ self.save_last = save_last
+ self.args = kwargs
+ self.sync_buffer = sync_buffer
+ self.file_client_args = file_client_args
+
+ def before_run(self, runner):
+ if not self.out_dir:
+ self.out_dir = runner.work_dir
+
+ self.file_client = FileClient.infer_client(self.file_client_args,
+ self.out_dir)
+
+ # if `self.out_dir` is not equal to `runner.work_dir`, it means that
+ # `self.out_dir` is set so the final `self.out_dir` is the
+ # concatenation of `self.out_dir` and the last level directory of
+ # `runner.work_dir`
+ if self.out_dir != runner.work_dir:
+ basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+ self.out_dir = self.file_client.join_path(self.out_dir, basename)
+
+ runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by '
+ f'{self.file_client.name}.'))
+
+ # disable the create_symlink option because some file backends do not
+ # allow to create a symlink
+ if 'create_symlink' in self.args:
+ if self.args[
+ 'create_symlink'] and not self.file_client.allow_symlink:
+ self.args['create_symlink'] = False
+ warnings.warn(
+ ('create_symlink is set as True by the user but is changed'
+ 'to be False because creating symbolic link is not '
+ f'allowed in {self.file_client.name}'))
+ else:
+ self.args['create_symlink'] = self.file_client.allow_symlink
+
+ def after_train_epoch(self, runner):
+ if not self.by_epoch:
+ return
+
+ # save checkpoint for following cases:
+ # 1. every ``self.interval`` epochs
+ # 2. reach the last epoch of training
+ if self.every_n_epochs(
+ runner, self.interval) or (self.save_last
+ and self.is_last_epoch(runner)):
+ runner.logger.info(
+ f'Saving checkpoint at {runner.epoch + 1} epochs')
+ if self.sync_buffer:
+ allreduce_params(runner.model.buffers())
+ self._save_checkpoint(runner)
+
+ @master_only
+ def _save_checkpoint(self, runner):
+ """Save the current checkpoint and delete unwanted checkpoint."""
+ runner.save_checkpoint(
+ self.out_dir, save_optimizer=self.save_optimizer, **self.args)
+ if runner.meta is not None:
+ if self.by_epoch:
+ cur_ckpt_filename = self.args.get(
+ 'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
+ else:
+ cur_ckpt_filename = self.args.get(
+ 'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
+ runner.meta.setdefault('hook_msgs', dict())
+ runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
+ self.out_dir, cur_ckpt_filename)
+ # remove other checkpoints
+ if self.max_keep_ckpts > 0:
+ if self.by_epoch:
+ name = 'epoch_{}.pth'
+ current_ckpt = runner.epoch + 1
+ else:
+ name = 'iter_{}.pth'
+ current_ckpt = runner.iter + 1
+ redundant_ckpts = range(
+ current_ckpt - self.max_keep_ckpts * self.interval, 0,
+ -self.interval)
+ filename_tmpl = self.args.get('filename_tmpl', name)
+ for _step in redundant_ckpts:
+ ckpt_path = self.file_client.join_path(
+ self.out_dir, filename_tmpl.format(_step))
+ if self.file_client.isfile(ckpt_path):
+ self.file_client.remove(ckpt_path)
+ else:
+ break
+
+ def after_train_iter(self, runner):
+ if self.by_epoch:
+ return
+
+ # save checkpoint for following cases:
+ # 1. every ``self.interval`` iterations
+ # 2. reach the last iteration of training
+ if self.every_n_iters(
+ runner, self.interval) or (self.save_last
+ and self.is_last_iter(runner)):
+ runner.logger.info(
+ f'Saving checkpoint at {runner.iter + 1} iterations')
+ if self.sync_buffer:
+ allreduce_params(runner.model.buffers())
+ self._save_checkpoint(runner)
diff --git a/src/custom_mmpkg/custom_mmcv/runner/hooks/closure.py b/src/custom_mmpkg/custom_mmcv/runner/hooks/closure.py
new file mode 100644
index 0000000000000000000000000000000000000000..b955f81f425be4ac3e6bb3f4aac653887989e872
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/hooks/closure.py
@@ -0,0 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class ClosureHook(Hook):
+
+ def __init__(self, fn_name, fn):
+ assert hasattr(self, fn_name)
+ assert callable(fn)
+ setattr(self, fn_name, fn)
diff --git a/src/custom_mmpkg/custom_mmcv/runner/hooks/ema.py b/src/custom_mmpkg/custom_mmcv/runner/hooks/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..15c7e68088f019802a59e7ae41cc1fe0c7f28f96
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/hooks/ema.py
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...parallel import is_module_wrapper
+from ..hooks.hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class EMAHook(Hook):
+ r"""Exponential Moving Average Hook.
+
+ Use Exponential Moving Average on all parameters of model in training
+ process. All parameters have a ema backup, which update by the formula
+ as below. EMAHook takes priority over EvalHook and CheckpointSaverHook.
+
+ .. math::
+
+ \text{Xema\_{t+1}} = (1 - \text{momentum}) \times
+ \text{Xema\_{t}} + \text{momentum} \times X_t
+
+ Args:
+ momentum (float): The momentum used for updating ema parameter.
+ Defaults to 0.0002.
+ interval (int): Update ema parameter every interval iteration.
+ Defaults to 1.
+ warm_up (int): During first warm_up steps, we may use smaller momentum
+ to update ema parameters more slowly. Defaults to 100.
+ resume_from (str): The checkpoint path. Defaults to None.
+ """
+
+ def __init__(self,
+ momentum=0.0002,
+ interval=1,
+ warm_up=100,
+ resume_from=None):
+ assert isinstance(interval, int) and interval > 0
+ self.warm_up = warm_up
+ self.interval = interval
+ assert momentum > 0 and momentum < 1
+ self.momentum = momentum**interval
+ self.checkpoint = resume_from
+
+ def before_run(self, runner):
+ """To resume model with it's ema parameters more friendly.
+
+ Register ema parameter as ``named_buffer`` to model
+ """
+ model = runner.model
+ if is_module_wrapper(model):
+ model = model.module
+ self.param_ema_buffer = {}
+ self.model_parameters = dict(model.named_parameters(recurse=True))
+ for name, value in self.model_parameters.items():
+ # "." is not allowed in module's buffer name
+ buffer_name = f"ema_{name.replace('.', '_')}"
+ self.param_ema_buffer[name] = buffer_name
+ model.register_buffer(buffer_name, value.data.clone())
+ self.model_buffers = dict(model.named_buffers(recurse=True))
+ if self.checkpoint is not None:
+ runner.resume(self.checkpoint)
+
+ def after_train_iter(self, runner):
+ """Update ema parameter every self.interval iterations."""
+ curr_step = runner.iter
+ # We warm up the momentum considering the instability at beginning
+ momentum = min(self.momentum,
+ (1 + curr_step) / (self.warm_up + curr_step))
+ if curr_step % self.interval != 0:
+ return
+ for name, parameter in self.model_parameters.items():
+ buffer_name = self.param_ema_buffer[name]
+ buffer_parameter = self.model_buffers[buffer_name]
+ buffer_parameter.mul_(1 - momentum).add_(momentum, parameter.data)
+
+ def after_train_epoch(self, runner):
+ """We load parameter values from ema backup to model before the
+ EvalHook."""
+ self._swap_ema_parameters()
+
+ def before_train_epoch(self, runner):
+ """We recover model's parameter from ema backup after last epoch's
+ EvalHook."""
+ self._swap_ema_parameters()
+
+ def _swap_ema_parameters(self):
+ """Swap the parameter of model with parameter in ema_buffer."""
+ for name, value in self.model_parameters.items():
+ temp = value.data.clone()
+ ema_buffer = self.model_buffers[self.param_ema_buffer[name]]
+ value.data.copy_(ema_buffer.data)
+ ema_buffer.data.copy_(temp)
diff --git a/src/custom_mmpkg/custom_mmcv/runner/hooks/evaluation.py b/src/custom_mmpkg/custom_mmcv/runner/hooks/evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d76699d3d2d297539cdd49e1fe0626c379ec26f
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/hooks/evaluation.py
@@ -0,0 +1,509 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import warnings
+from math import inf
+
+import torch.distributed as dist
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.utils.data import DataLoader
+
+from custom_mmpkg.custom_mmcv.fileio import FileClient
+from custom_mmpkg.custom_mmcv.utils import is_seq_of
+from .hook import Hook
+from .logger import LoggerHook
+
+
+class EvalHook(Hook):
+ """Non-Distributed evaluation hook.
+
+ This hook will regularly perform evaluation in a given interval when
+ performing in non-distributed environment.
+
+ Args:
+ dataloader (DataLoader): A PyTorch dataloader, whose dataset has
+ implemented ``evaluate`` function.
+ start (int | None, optional): Evaluation starting epoch. It enables
+ evaluation before the training starts if ``start`` <= the resuming
+ epoch. If None, whether to evaluate is merely decided by
+ ``interval``. Default: None.
+ interval (int): Evaluation interval. Default: 1.
+ by_epoch (bool): Determine perform evaluation by epoch or by iteration.
+ If set to True, it will perform by epoch. Otherwise, by iteration.
+ Default: True.
+ save_best (str, optional): If a metric is specified, it would measure
+ the best checkpoint during evaluation. The information about best
+ checkpoint would be saved in ``runner.meta['hook_msgs']`` to keep
+ best score value and best checkpoint path, which will be also
+ loaded when resume checkpoint. Options are the evaluation metrics
+ on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox
+ detection and instance segmentation. ``AR@100`` for proposal
+ recall. If ``save_best`` is ``auto``, the first key of the returned
+ ``OrderedDict`` result will be used. Default: None.
+ rule (str | None, optional): Comparison rule for best score. If set to
+ None, it will infer a reasonable rule. Keys such as 'acc', 'top'
+ .etc will be inferred by 'greater' rule. Keys contain 'loss' will
+ be inferred by 'less' rule. Options are 'greater', 'less', None.
+ Default: None.
+ test_fn (callable, optional): test a model with samples from a
+ dataloader, and return the test results. If ``None``, the default
+ test function ``mmcv.engine.single_gpu_test`` will be used.
+ (default: ``None``)
+ greater_keys (List[str] | None, optional): Metric keys that will be
+ inferred by 'greater' comparison rule. If ``None``,
+ _default_greater_keys will be used. (default: ``None``)
+ less_keys (List[str] | None, optional): Metric keys that will be
+ inferred by 'less' comparison rule. If ``None``, _default_less_keys
+ will be used. (default: ``None``)
+ out_dir (str, optional): The root directory to save checkpoints. If not
+ specified, `runner.work_dir` will be used by default. If specified,
+ the `out_dir` will be the concatenation of `out_dir` and the last
+ level directory of `runner.work_dir`.
+ `New in version 1.3.16.`
+ file_client_args (dict): Arguments to instantiate a FileClient.
+ See :class:`mmcv.fileio.FileClient` for details. Default: None.
+ `New in version 1.3.16.`
+ **eval_kwargs: Evaluation arguments fed into the evaluate function of
+ the dataset.
+
+ Notes:
+ If new arguments are added for EvalHook, tools/test.py,
+ tools/eval_metric.py may be affected.
+ """
+
+ # Since the key for determine greater or less is related to the downstream
+ # tasks, downstream repos may need to overwrite the following inner
+ # variable accordingly.
+
+ rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
+ init_value_map = {'greater': -inf, 'less': inf}
+ _default_greater_keys = [
+ 'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
+ 'mAcc', 'aAcc'
+ ]
+ _default_less_keys = ['loss']
+
+ def __init__(self,
+ dataloader,
+ start=None,
+ interval=1,
+ by_epoch=True,
+ save_best=None,
+ rule=None,
+ test_fn=None,
+ greater_keys=None,
+ less_keys=None,
+ out_dir=None,
+ file_client_args=None,
+ **eval_kwargs):
+ if not isinstance(dataloader, DataLoader):
+ raise TypeError(f'dataloader must be a pytorch DataLoader, '
+ f'but got {type(dataloader)}')
+
+ if interval <= 0:
+ raise ValueError(f'interval must be a positive number, '
+ f'but got {interval}')
+
+ assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean'
+
+ if start is not None and start < 0:
+ raise ValueError(f'The evaluation start epoch {start} is smaller '
+ f'than 0')
+
+ self.dataloader = dataloader
+ self.interval = interval
+ self.start = start
+ self.by_epoch = by_epoch
+
+ assert isinstance(save_best, str) or save_best is None, \
+ '""save_best"" should be a str or None ' \
+ f'rather than {type(save_best)}'
+ self.save_best = save_best
+ self.eval_kwargs = eval_kwargs
+ self.initial_flag = True
+
+ if test_fn is None:
+ from custom_mmpkg.custom_mmcv.engine import single_gpu_test
+ self.test_fn = single_gpu_test
+ else:
+ self.test_fn = test_fn
+
+ if greater_keys is None:
+ self.greater_keys = self._default_greater_keys
+ else:
+ if not isinstance(greater_keys, (list, tuple)):
+ greater_keys = (greater_keys, )
+ assert is_seq_of(greater_keys, str)
+ self.greater_keys = greater_keys
+
+ if less_keys is None:
+ self.less_keys = self._default_less_keys
+ else:
+ if not isinstance(less_keys, (list, tuple)):
+ less_keys = (less_keys, )
+ assert is_seq_of(less_keys, str)
+ self.less_keys = less_keys
+
+ if self.save_best is not None:
+ self.best_ckpt_path = None
+ self._init_rule(rule, self.save_best)
+
+ self.out_dir = out_dir
+ self.file_client_args = file_client_args
+
+ def _init_rule(self, rule, key_indicator):
+ """Initialize rule, key_indicator, comparison_func, and best score.
+
+ Here is the rule to determine which rule is used for key indicator
+ when the rule is not specific (note that the key indicator matching
+ is case-insensitive):
+ 1. If the key indicator is in ``self.greater_keys``, the rule will be
+ specified as 'greater'.
+ 2. Or if the key indicator is in ``self.less_keys``, the rule will be
+ specified as 'less'.
+ 3. Or if the key indicator is equal to the substring in any one item
+ in ``self.greater_keys``, the rule will be specified as 'greater'.
+ 4. Or if the key indicator is equal to the substring in any one item
+ in ``self.less_keys``, the rule will be specified as 'less'.
+
+ Args:
+ rule (str | None): Comparison rule for best score.
+ key_indicator (str | None): Key indicator to determine the
+ comparison rule.
+ """
+ if rule not in self.rule_map and rule is not None:
+ raise KeyError(f'rule must be greater, less or None, '
+ f'but got {rule}.')
+
+ if rule is None:
+ if key_indicator != 'auto':
+ # `_lc` here means we use the lower case of keys for
+ # case-insensitive matching
+ key_indicator_lc = key_indicator.lower()
+ greater_keys = [key.lower() for key in self.greater_keys]
+ less_keys = [key.lower() for key in self.less_keys]
+
+ if key_indicator_lc in greater_keys:
+ rule = 'greater'
+ elif key_indicator_lc in less_keys:
+ rule = 'less'
+ elif any(key in key_indicator_lc for key in greater_keys):
+ rule = 'greater'
+ elif any(key in key_indicator_lc for key in less_keys):
+ rule = 'less'
+ else:
+ raise ValueError(f'Cannot infer the rule for key '
+ f'{key_indicator}, thus a specific rule '
+ f'must be specified.')
+ self.rule = rule
+ self.key_indicator = key_indicator
+ if self.rule is not None:
+ self.compare_func = self.rule_map[self.rule]
+
+ def before_run(self, runner):
+ if not self.out_dir:
+ self.out_dir = runner.work_dir
+
+ self.file_client = FileClient.infer_client(self.file_client_args,
+ self.out_dir)
+
+ # if `self.out_dir` is not equal to `runner.work_dir`, it means that
+ # `self.out_dir` is set so the final `self.out_dir` is the
+ # concatenation of `self.out_dir` and the last level directory of
+ # `runner.work_dir`
+ if self.out_dir != runner.work_dir:
+ basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+ self.out_dir = self.file_client.join_path(self.out_dir, basename)
+ runner.logger.info(
+ (f'The best checkpoint will be saved to {self.out_dir} by '
+ f'{self.file_client.name}'))
+
+ if self.save_best is not None:
+ if runner.meta is None:
+ warnings.warn('runner.meta is None. Creating an empty one.')
+ runner.meta = dict()
+ runner.meta.setdefault('hook_msgs', dict())
+ self.best_ckpt_path = runner.meta['hook_msgs'].get(
+ 'best_ckpt', None)
+
+ def before_train_iter(self, runner):
+ """Evaluate the model only at the start of training by iteration."""
+ if self.by_epoch or not self.initial_flag:
+ return
+ if self.start is not None and runner.iter >= self.start:
+ self.after_train_iter(runner)
+ self.initial_flag = False
+
+ def before_train_epoch(self, runner):
+ """Evaluate the model only at the start of training by epoch."""
+ if not (self.by_epoch and self.initial_flag):
+ return
+ if self.start is not None and runner.epoch >= self.start:
+ self.after_train_epoch(runner)
+ self.initial_flag = False
+
+ def after_train_iter(self, runner):
+ """Called after every training iter to evaluate the results."""
+ if not self.by_epoch and self._should_evaluate(runner):
+ # Because the priority of EvalHook is higher than LoggerHook, the
+ # training log and the evaluating log are mixed. Therefore,
+ # we need to dump the training log and clear it before evaluating
+ # log is generated. In addition, this problem will only appear in
+ # `IterBasedRunner` whose `self.by_epoch` is False, because
+ # `EpochBasedRunner` whose `self.by_epoch` is True calls
+ # `_do_evaluate` in `after_train_epoch` stage, and at this stage
+ # the training log has been printed, so it will not cause any
+ # problem. more details at
+ # https://github.com/open-mmlab/mmsegmentation/issues/694
+ for hook in runner._hooks:
+ if isinstance(hook, LoggerHook):
+ hook.after_train_iter(runner)
+ runner.log_buffer.clear()
+
+ self._do_evaluate(runner)
+
+ def after_train_epoch(self, runner):
+ """Called after every training epoch to evaluate the results."""
+ if self.by_epoch and self._should_evaluate(runner):
+ self._do_evaluate(runner)
+
+ def _do_evaluate(self, runner):
+ """perform evaluation and save ckpt."""
+ results = self.test_fn(runner.model, self.dataloader)
+ runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+ key_score = self.evaluate(runner, results)
+ # the key_score may be `None` so it needs to skip the action to save
+ # the best checkpoint
+ if self.save_best and key_score:
+ self._save_ckpt(runner, key_score)
+
+ def _should_evaluate(self, runner):
+ """Judge whether to perform evaluation.
+
+ Here is the rule to judge whether to perform evaluation:
+ 1. It will not perform evaluation during the epoch/iteration interval,
+ which is determined by ``self.interval``.
+ 2. It will not perform evaluation if the start time is larger than
+ current time.
+ 3. It will not perform evaluation when current time is larger than
+ the start time but during epoch/iteration interval.
+
+ Returns:
+ bool: The flag indicating whether to perform evaluation.
+ """
+ if self.by_epoch:
+ current = runner.epoch
+ check_time = self.every_n_epochs
+ else:
+ current = runner.iter
+ check_time = self.every_n_iters
+
+ if self.start is None:
+ if not check_time(runner, self.interval):
+ # No evaluation during the interval.
+ return False
+ elif (current + 1) < self.start:
+ # No evaluation if start is larger than the current time.
+ return False
+ else:
+ # Evaluation only at epochs/iters 3, 5, 7...
+ # if start==3 and interval==2
+ if (current + 1 - self.start) % self.interval:
+ return False
+ return True
+
+ def _save_ckpt(self, runner, key_score):
+ """Save the best checkpoint.
+
+ It will compare the score according to the compare function, write
+ related information (best score, best checkpoint path) and save the
+ best checkpoint into ``work_dir``.
+ """
+ if self.by_epoch:
+ current = f'epoch_{runner.epoch + 1}'
+ cur_type, cur_time = 'epoch', runner.epoch + 1
+ else:
+ current = f'iter_{runner.iter + 1}'
+ cur_type, cur_time = 'iter', runner.iter + 1
+
+ best_score = runner.meta['hook_msgs'].get(
+ 'best_score', self.init_value_map[self.rule])
+ if self.compare_func(key_score, best_score):
+ best_score = key_score
+ runner.meta['hook_msgs']['best_score'] = best_score
+
+ if self.best_ckpt_path and self.file_client.isfile(
+ self.best_ckpt_path):
+ self.file_client.remove(self.best_ckpt_path)
+ runner.logger.info(
+ (f'The previous best checkpoint {self.best_ckpt_path} was '
+ 'removed'))
+
+ best_ckpt_name = f'best_{self.key_indicator}_{current}.pth'
+ self.best_ckpt_path = self.file_client.join_path(
+ self.out_dir, best_ckpt_name)
+ runner.meta['hook_msgs']['best_ckpt'] = self.best_ckpt_path
+
+ runner.save_checkpoint(
+ self.out_dir, best_ckpt_name, create_symlink=False)
+ runner.logger.info(
+ f'Now best checkpoint is saved as {best_ckpt_name}.')
+ runner.logger.info(
+ f'Best {self.key_indicator} is {best_score:0.4f} '
+ f'at {cur_time} {cur_type}.')
+
+ def evaluate(self, runner, results):
+ """Evaluate the results.
+
+ Args:
+ runner (:obj:`mmcv.Runner`): The underlined training runner.
+ results (list): Output results.
+ """
+ eval_res = self.dataloader.dataset.evaluate(
+ results, logger=runner.logger, **self.eval_kwargs)
+
+ for name, val in eval_res.items():
+ runner.log_buffer.output[name] = val
+ runner.log_buffer.ready = True
+
+ if self.save_best is not None:
+ # If the performance of model is pool, the `eval_res` may be an
+ # empty dict and it will raise exception when `self.save_best` is
+ # not None. More details at
+ # https://github.com/open-mmlab/mmdetection/issues/6265.
+ if not eval_res:
+ warnings.warn(
+ 'Since `eval_res` is an empty dict, the behavior to save '
+ 'the best checkpoint will be skipped in this evaluation.')
+ return None
+
+ if self.key_indicator == 'auto':
+ # infer from eval_results
+ self._init_rule(self.rule, list(eval_res.keys())[0])
+ return eval_res[self.key_indicator]
+
+ return None
+
+
+class DistEvalHook(EvalHook):
+ """Distributed evaluation hook.
+
+ This hook will regularly perform evaluation in a given interval when
+ performing in distributed environment.
+
+ Args:
+ dataloader (DataLoader): A PyTorch dataloader, whose dataset has
+ implemented ``evaluate`` function.
+ start (int | None, optional): Evaluation starting epoch. It enables
+ evaluation before the training starts if ``start`` <= the resuming
+ epoch. If None, whether to evaluate is merely decided by
+ ``interval``. Default: None.
+ interval (int): Evaluation interval. Default: 1.
+ by_epoch (bool): Determine perform evaluation by epoch or by iteration.
+ If set to True, it will perform by epoch. Otherwise, by iteration.
+ default: True.
+ save_best (str, optional): If a metric is specified, it would measure
+ the best checkpoint during evaluation. The information about best
+ checkpoint would be saved in ``runner.meta['hook_msgs']`` to keep
+ best score value and best checkpoint path, which will be also
+ loaded when resume checkpoint. Options are the evaluation metrics
+ on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox
+ detection and instance segmentation. ``AR@100`` for proposal
+ recall. If ``save_best`` is ``auto``, the first key of the returned
+ ``OrderedDict`` result will be used. Default: None.
+ rule (str | None, optional): Comparison rule for best score. If set to
+ None, it will infer a reasonable rule. Keys such as 'acc', 'top'
+ .etc will be inferred by 'greater' rule. Keys contain 'loss' will
+ be inferred by 'less' rule. Options are 'greater', 'less', None.
+ Default: None.
+ test_fn (callable, optional): test a model with samples from a
+ dataloader in a multi-gpu manner, and return the test results. If
+ ``None``, the default test function ``mmcv.engine.multi_gpu_test``
+ will be used. (default: ``None``)
+ tmpdir (str | None): Temporary directory to save the results of all
+ processes. Default: None.
+ gpu_collect (bool): Whether to use gpu or cpu to collect results.
+ Default: False.
+ broadcast_bn_buffer (bool): Whether to broadcast the
+ buffer(running_mean and running_var) of rank 0 to other rank
+ before evaluation. Default: True.
+ out_dir (str, optional): The root directory to save checkpoints. If not
+ specified, `runner.work_dir` will be used by default. If specified,
+ the `out_dir` will be the concatenation of `out_dir` and the last
+ level directory of `runner.work_dir`.
+ file_client_args (dict): Arguments to instantiate a FileClient.
+ See :class:`mmcv.fileio.FileClient` for details. Default: None.
+ **eval_kwargs: Evaluation arguments fed into the evaluate function of
+ the dataset.
+ """
+
+ def __init__(self,
+ dataloader,
+ start=None,
+ interval=1,
+ by_epoch=True,
+ save_best=None,
+ rule=None,
+ test_fn=None,
+ greater_keys=None,
+ less_keys=None,
+ broadcast_bn_buffer=True,
+ tmpdir=None,
+ gpu_collect=False,
+ out_dir=None,
+ file_client_args=None,
+ **eval_kwargs):
+
+ if test_fn is None:
+ from custom_mmpkg.custom_mmcv.engine import multi_gpu_test
+ test_fn = multi_gpu_test
+
+ super().__init__(
+ dataloader,
+ start=start,
+ interval=interval,
+ by_epoch=by_epoch,
+ save_best=save_best,
+ rule=rule,
+ test_fn=test_fn,
+ greater_keys=greater_keys,
+ less_keys=less_keys,
+ out_dir=out_dir,
+ file_client_args=file_client_args,
+ **eval_kwargs)
+
+ self.broadcast_bn_buffer = broadcast_bn_buffer
+ self.tmpdir = tmpdir
+ self.gpu_collect = gpu_collect
+
+ def _do_evaluate(self, runner):
+ """perform evaluation and save ckpt."""
+ # Synchronization of BatchNorm's buffer (running_mean
+ # and running_var) is not supported in the DDP of pytorch,
+ # which may cause the inconsistent performance of models in
+ # different ranks, so we broadcast BatchNorm's buffers
+ # of rank 0 to other ranks to avoid this.
+ if self.broadcast_bn_buffer:
+ model = runner.model
+ for name, module in model.named_modules():
+ if isinstance(module,
+ _BatchNorm) and module.track_running_stats:
+ dist.broadcast(module.running_var, 0)
+ dist.broadcast(module.running_mean, 0)
+
+ tmpdir = self.tmpdir
+ if tmpdir is None:
+ tmpdir = osp.join(runner.work_dir, '.eval_hook')
+
+ results = self.test_fn(
+ runner.model,
+ self.dataloader,
+ tmpdir=tmpdir,
+ gpu_collect=self.gpu_collect)
+ if runner.rank == 0:
+ print('\n')
+ runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+ key_score = self.evaluate(runner, results)
+ # the key_score may be `None` so it needs to skip the action to
+ # save the best checkpoint
+ if self.save_best and key_score:
+ self._save_ckpt(runner, key_score)
diff --git a/src/custom_mmpkg/custom_mmcv/runner/hooks/hook.py b/src/custom_mmpkg/custom_mmcv/runner/hooks/hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e497e18e080f726fc95e62386248425a8848b3f
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/hooks/hook.py
@@ -0,0 +1,92 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from custom_mmpkg.custom_mmcv.utils import Registry, is_method_overridden
+
+HOOKS = Registry('hook')
+
+
+class Hook:
+ stages = ('before_run', 'before_train_epoch', 'before_train_iter',
+ 'after_train_iter', 'after_train_epoch', 'before_val_epoch',
+ 'before_val_iter', 'after_val_iter', 'after_val_epoch',
+ 'after_run')
+
+ def before_run(self, runner):
+ pass
+
+ def after_run(self, runner):
+ pass
+
+ def before_epoch(self, runner):
+ pass
+
+ def after_epoch(self, runner):
+ pass
+
+ def before_iter(self, runner):
+ pass
+
+ def after_iter(self, runner):
+ pass
+
+ def before_train_epoch(self, runner):
+ self.before_epoch(runner)
+
+ def before_val_epoch(self, runner):
+ self.before_epoch(runner)
+
+ def after_train_epoch(self, runner):
+ self.after_epoch(runner)
+
+ def after_val_epoch(self, runner):
+ self.after_epoch(runner)
+
+ def before_train_iter(self, runner):
+ self.before_iter(runner)
+
+ def before_val_iter(self, runner):
+ self.before_iter(runner)
+
+ def after_train_iter(self, runner):
+ self.after_iter(runner)
+
+ def after_val_iter(self, runner):
+ self.after_iter(runner)
+
+ def every_n_epochs(self, runner, n):
+ return (runner.epoch + 1) % n == 0 if n > 0 else False
+
+ def every_n_inner_iters(self, runner, n):
+ return (runner.inner_iter + 1) % n == 0 if n > 0 else False
+
+ def every_n_iters(self, runner, n):
+ return (runner.iter + 1) % n == 0 if n > 0 else False
+
+ def end_of_epoch(self, runner):
+ return runner.inner_iter + 1 == len(runner.data_loader)
+
+ def is_last_epoch(self, runner):
+ return runner.epoch + 1 == runner._max_epochs
+
+ def is_last_iter(self, runner):
+ return runner.iter + 1 == runner._max_iters
+
+ def get_triggered_stages(self):
+ trigger_stages = set()
+ for stage in Hook.stages:
+ if is_method_overridden(stage, Hook, self):
+ trigger_stages.add(stage)
+
+ # some methods will be triggered in multi stages
+ # use this dict to map method to stages.
+ method_stages_map = {
+ 'before_epoch': ['before_train_epoch', 'before_val_epoch'],
+ 'after_epoch': ['after_train_epoch', 'after_val_epoch'],
+ 'before_iter': ['before_train_iter', 'before_val_iter'],
+ 'after_iter': ['after_train_iter', 'after_val_iter'],
+ }
+
+ for method, map_stages in method_stages_map.items():
+ if is_method_overridden(method, Hook, self):
+ trigger_stages.update(map_stages)
+
+ return [stage for stage in Hook.stages if stage in trigger_stages]
diff --git a/src/custom_mmpkg/custom_mmcv/runner/hooks/iter_timer.py b/src/custom_mmpkg/custom_mmcv/runner/hooks/iter_timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfd5002fe85ffc6992155ac01003878064a1d9be
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/hooks/iter_timer.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import time
+
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class IterTimerHook(Hook):
+
+ def before_epoch(self, runner):
+ self.t = time.time()
+
+ def before_iter(self, runner):
+ runner.log_buffer.update({'data_time': time.time() - self.t})
+
+ def after_iter(self, runner):
+ runner.log_buffer.update({'time': time.time() - self.t})
+ self.t = time.time()
diff --git a/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/__init__.py b/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b6b345640a895368ac8a647afef6f24333d90e
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import LoggerHook
+from .dvclive import DvcliveLoggerHook
+from .mlflow import MlflowLoggerHook
+from .neptune import NeptuneLoggerHook
+from .pavi import PaviLoggerHook
+from .tensorboard import TensorboardLoggerHook
+from .text import TextLoggerHook
+from .wandb import WandbLoggerHook
+
+__all__ = [
+ 'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook',
+ 'TensorboardLoggerHook', 'TextLoggerHook', 'WandbLoggerHook',
+ 'NeptuneLoggerHook', 'DvcliveLoggerHook'
+]
diff --git a/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/base.py b/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..f845256729458ced821762a1b8ef881e17ff9955
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/base.py
@@ -0,0 +1,166 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numbers
+from abc import ABCMeta, abstractmethod
+
+import numpy as np
+import torch
+
+from ..hook import Hook
+
+
+class LoggerHook(Hook):
+ """Base class for logger hooks.
+
+ Args:
+ interval (int): Logging interval (every k iterations).
+ ignore_last (bool): Ignore the log of last iterations in each epoch
+ if less than `interval`.
+ reset_flag (bool): Whether to clear the output buffer after logging.
+ by_epoch (bool): Whether EpochBasedRunner is used.
+ """
+
+ __metaclass__ = ABCMeta
+
+ def __init__(self,
+ interval=10,
+ ignore_last=True,
+ reset_flag=False,
+ by_epoch=True):
+ self.interval = interval
+ self.ignore_last = ignore_last
+ self.reset_flag = reset_flag
+ self.by_epoch = by_epoch
+
+ @abstractmethod
+ def log(self, runner):
+ pass
+
+ @staticmethod
+ def is_scalar(val, include_np=True, include_torch=True):
+ """Tell the input variable is a scalar or not.
+
+ Args:
+ val: Input variable.
+ include_np (bool): Whether include 0-d np.ndarray as a scalar.
+ include_torch (bool): Whether include 0-d torch.Tensor as a scalar.
+
+ Returns:
+ bool: True or False.
+ """
+ if isinstance(val, numbers.Number):
+ return True
+ elif include_np and isinstance(val, np.ndarray) and val.ndim == 0:
+ return True
+ elif include_torch and isinstance(val, torch.Tensor) and len(val) == 1:
+ return True
+ else:
+ return False
+
+ def get_mode(self, runner):
+ if runner.mode == 'train':
+ if 'time' in runner.log_buffer.output:
+ mode = 'train'
+ else:
+ mode = 'val'
+ elif runner.mode == 'val':
+ mode = 'val'
+ else:
+ raise ValueError(f"runner mode should be 'train' or 'val', "
+ f'but got {runner.mode}')
+ return mode
+
+ def get_epoch(self, runner):
+ if runner.mode == 'train':
+ epoch = runner.epoch + 1
+ elif runner.mode == 'val':
+ # normal val mode
+ # runner.epoch += 1 has been done before val workflow
+ epoch = runner.epoch
+ else:
+ raise ValueError(f"runner mode should be 'train' or 'val', "
+ f'but got {runner.mode}')
+ return epoch
+
+ def get_iter(self, runner, inner_iter=False):
+ """Get the current training iteration step."""
+ if self.by_epoch and inner_iter:
+ current_iter = runner.inner_iter + 1
+ else:
+ current_iter = runner.iter + 1
+ return current_iter
+
+ def get_lr_tags(self, runner):
+ tags = {}
+ lrs = runner.current_lr()
+ if isinstance(lrs, dict):
+ for name, value in lrs.items():
+ tags[f'learning_rate/{name}'] = value[0]
+ else:
+ tags['learning_rate'] = lrs[0]
+ return tags
+
+ def get_momentum_tags(self, runner):
+ tags = {}
+ momentums = runner.current_momentum()
+ if isinstance(momentums, dict):
+ for name, value in momentums.items():
+ tags[f'momentum/{name}'] = value[0]
+ else:
+ tags['momentum'] = momentums[0]
+ return tags
+
+ def get_loggable_tags(self,
+ runner,
+ allow_scalar=True,
+ allow_text=False,
+ add_mode=True,
+ tags_to_skip=('time', 'data_time')):
+ tags = {}
+ for var, val in runner.log_buffer.output.items():
+ if var in tags_to_skip:
+ continue
+ if self.is_scalar(val) and not allow_scalar:
+ continue
+ if isinstance(val, str) and not allow_text:
+ continue
+ if add_mode:
+ var = f'{self.get_mode(runner)}/{var}'
+ tags[var] = val
+ tags.update(self.get_lr_tags(runner))
+ tags.update(self.get_momentum_tags(runner))
+ return tags
+
+ def before_run(self, runner):
+ for hook in runner.hooks[::-1]:
+ if isinstance(hook, LoggerHook):
+ hook.reset_flag = True
+ break
+
+ def before_epoch(self, runner):
+ runner.log_buffer.clear() # clear logs of last epoch
+
+ def after_train_iter(self, runner):
+ if self.by_epoch and self.every_n_inner_iters(runner, self.interval):
+ runner.log_buffer.average(self.interval)
+ elif not self.by_epoch and self.every_n_iters(runner, self.interval):
+ runner.log_buffer.average(self.interval)
+ elif self.end_of_epoch(runner) and not self.ignore_last:
+ # not precise but more stable
+ runner.log_buffer.average(self.interval)
+
+ if runner.log_buffer.ready:
+ self.log(runner)
+ if self.reset_flag:
+ runner.log_buffer.clear_output()
+
+ def after_train_epoch(self, runner):
+ if runner.log_buffer.ready:
+ self.log(runner)
+ if self.reset_flag:
+ runner.log_buffer.clear_output()
+
+ def after_val_epoch(self, runner):
+ runner.log_buffer.average()
+ self.log(runner)
+ if self.reset_flag:
+ runner.log_buffer.clear_output()
diff --git a/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/dvclive.py b/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/dvclive.py
new file mode 100644
index 0000000000000000000000000000000000000000..687cdc58c0336c92b1e4f9a410ba67ebaab2bc7a
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/dvclive.py
@@ -0,0 +1,58 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class DvcliveLoggerHook(LoggerHook):
+ """Class to log metrics with dvclive.
+
+ It requires `dvclive`_ to be installed.
+
+ Args:
+ path (str): Directory where dvclive will write TSV log files.
+ interval (int): Logging interval (every k iterations).
+ Default 10.
+ ignore_last (bool): Ignore the log of last iterations in each epoch
+ if less than `interval`.
+ Default: True.
+ reset_flag (bool): Whether to clear the output buffer after logging.
+ Default: True.
+ by_epoch (bool): Whether EpochBasedRunner is used.
+ Default: True.
+
+ .. _dvclive:
+ https://dvc.org/doc/dvclive
+ """
+
+ def __init__(self,
+ path,
+ interval=10,
+ ignore_last=True,
+ reset_flag=True,
+ by_epoch=True):
+
+ super(DvcliveLoggerHook, self).__init__(interval, ignore_last,
+ reset_flag, by_epoch)
+ self.path = path
+ self.import_dvclive()
+
+ def import_dvclive(self):
+ try:
+ import dvclive
+ except ImportError:
+ raise ImportError(
+ 'Please run "pip install dvclive" to install dvclive')
+ self.dvclive = dvclive
+
+ @master_only
+ def before_run(self, runner):
+ self.dvclive.init(self.path)
+
+ @master_only
+ def log(self, runner):
+ tags = self.get_loggable_tags(runner)
+ if tags:
+ for k, v in tags.items():
+ self.dvclive.log(k, v, step=self.get_iter(runner))
diff --git a/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/mlflow.py b/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/mlflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9a72592be47b534ce22573775fd5a7e8e86d72d
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/mlflow.py
@@ -0,0 +1,78 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class MlflowLoggerHook(LoggerHook):
+
+ def __init__(self,
+ exp_name=None,
+ tags=None,
+ log_model=True,
+ interval=10,
+ ignore_last=True,
+ reset_flag=False,
+ by_epoch=True):
+ """Class to log metrics and (optionally) a trained model to MLflow.
+
+ It requires `MLflow`_ to be installed.
+
+ Args:
+ exp_name (str, optional): Name of the experiment to be used.
+ Default None.
+ If not None, set the active experiment.
+ If experiment does not exist, an experiment with provided name
+ will be created.
+ tags (dict of str: str, optional): Tags for the current run.
+ Default None.
+ If not None, set tags for the current run.
+ log_model (bool, optional): Whether to log an MLflow artifact.
+ Default True.
+ If True, log runner.model as an MLflow artifact
+ for the current run.
+ interval (int): Logging interval (every k iterations).
+ ignore_last (bool): Ignore the log of last iterations in each epoch
+ if less than `interval`.
+ reset_flag (bool): Whether to clear the output buffer after logging
+ by_epoch (bool): Whether EpochBasedRunner is used.
+
+ .. _MLflow:
+ https://www.mlflow.org/docs/latest/index.html
+ """
+ super(MlflowLoggerHook, self).__init__(interval, ignore_last,
+ reset_flag, by_epoch)
+ self.import_mlflow()
+ self.exp_name = exp_name
+ self.tags = tags
+ self.log_model = log_model
+
+ def import_mlflow(self):
+ try:
+ import mlflow
+ import mlflow.pytorch as mlflow_pytorch
+ except ImportError:
+ raise ImportError(
+ 'Please run "pip install mlflow" to install mlflow')
+ self.mlflow = mlflow
+ self.mlflow_pytorch = mlflow_pytorch
+
+ @master_only
+ def before_run(self, runner):
+ super(MlflowLoggerHook, self).before_run(runner)
+ if self.exp_name is not None:
+ self.mlflow.set_experiment(self.exp_name)
+ if self.tags is not None:
+ self.mlflow.set_tags(self.tags)
+
+ @master_only
+ def log(self, runner):
+ tags = self.get_loggable_tags(runner)
+ if tags:
+ self.mlflow.log_metrics(tags, step=self.get_iter(runner))
+
+ @master_only
+ def after_run(self, runner):
+ if self.log_model:
+ self.mlflow_pytorch.log_model(runner.model, 'models')
diff --git a/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/neptune.py b/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/neptune.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a38772b0c93a8608f32c6357b8616e77c139dc9
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/neptune.py
@@ -0,0 +1,82 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class NeptuneLoggerHook(LoggerHook):
+ """Class to log metrics to NeptuneAI.
+
+ It requires `neptune-client` to be installed.
+
+ Args:
+ init_kwargs (dict): a dict contains the initialization keys as below:
+ - project (str): Name of a project in a form of
+ namespace/project_name. If None, the value of
+ NEPTUNE_PROJECT environment variable will be taken.
+ - api_token (str): User’s API token.
+ If None, the value of NEPTUNE_API_TOKEN environment
+ variable will be taken. Note: It is strongly recommended
+ to use NEPTUNE_API_TOKEN environment variable rather than
+ placing your API token in plain text in your source code.
+ - name (str, optional, default is 'Untitled'): Editable name of
+ the run. Name is displayed in the run's Details and in
+ Runs table as a column.
+ Check https://docs.neptune.ai/api-reference/neptune#init for
+ more init arguments.
+ interval (int): Logging interval (every k iterations).
+ ignore_last (bool): Ignore the log of last iterations in each epoch
+ if less than `interval`.
+ reset_flag (bool): Whether to clear the output buffer after logging
+ by_epoch (bool): Whether EpochBasedRunner is used.
+
+ .. _NeptuneAI:
+ https://docs.neptune.ai/you-should-know/logging-metadata
+ """
+
+ def __init__(self,
+ init_kwargs=None,
+ interval=10,
+ ignore_last=True,
+ reset_flag=True,
+ with_step=True,
+ by_epoch=True):
+
+ super(NeptuneLoggerHook, self).__init__(interval, ignore_last,
+ reset_flag, by_epoch)
+ self.import_neptune()
+ self.init_kwargs = init_kwargs
+ self.with_step = with_step
+
+ def import_neptune(self):
+ try:
+ import neptune.new as neptune
+ except ImportError:
+ raise ImportError(
+ 'Please run "pip install neptune-client" to install neptune')
+ self.neptune = neptune
+ self.run = None
+
+ @master_only
+ def before_run(self, runner):
+ if self.init_kwargs:
+ self.run = self.neptune.init(**self.init_kwargs)
+ else:
+ self.run = self.neptune.init()
+
+ @master_only
+ def log(self, runner):
+ tags = self.get_loggable_tags(runner)
+ if tags:
+ for tag_name, tag_value in tags.items():
+ if self.with_step:
+ self.run[tag_name].log(
+ tag_value, step=self.get_iter(runner))
+ else:
+ tags['global_step'] = self.get_iter(runner)
+ self.run[tag_name].log(tags)
+
+ @master_only
+ def after_run(self, runner):
+ self.run.stop()
diff --git a/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/pavi.py b/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/pavi.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c5f14224cc4762cd1ef18a5d3b49d023f22a1dc
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/pavi.py
@@ -0,0 +1,117 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+import os
+import os.path as osp
+
+import torch
+import yaml
+
+import custom_mmpkg.custom_mmcv as mmcv
+from ....parallel.utils import is_module_wrapper
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class PaviLoggerHook(LoggerHook):
+
+ def __init__(self,
+ init_kwargs=None,
+ add_graph=False,
+ add_last_ckpt=False,
+ interval=10,
+ ignore_last=True,
+ reset_flag=False,
+ by_epoch=True,
+ img_key='img_info'):
+ super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag,
+ by_epoch)
+ self.init_kwargs = init_kwargs
+ self.add_graph = add_graph
+ self.add_last_ckpt = add_last_ckpt
+ self.img_key = img_key
+
+ @master_only
+ def before_run(self, runner):
+ super(PaviLoggerHook, self).before_run(runner)
+ try:
+ from pavi import SummaryWriter
+ except ImportError:
+ raise ImportError('Please run "pip install pavi" to install pavi.')
+
+ self.run_name = runner.work_dir.split('/')[-1]
+
+ if not self.init_kwargs:
+ self.init_kwargs = dict()
+ self.init_kwargs['name'] = self.run_name
+ self.init_kwargs['model'] = runner._model_name
+ if runner.meta is not None:
+ if 'config_dict' in runner.meta:
+ config_dict = runner.meta['config_dict']
+ assert isinstance(
+ config_dict,
+ dict), ('meta["config_dict"] has to be of a dict, '
+ f'but got {type(config_dict)}')
+ elif 'config_file' in runner.meta:
+ config_file = runner.meta['config_file']
+ config_dict = dict(mmcv.Config.fromfile(config_file))
+ else:
+ config_dict = None
+ if config_dict is not None:
+ # 'max_.*iter' is parsed in pavi sdk as the maximum iterations
+ # to properly set up the progress bar.
+ config_dict = config_dict.copy()
+ config_dict.setdefault('max_iter', runner.max_iters)
+ # non-serializable values are first converted in
+ # mmcv.dump to json
+ config_dict = json.loads(
+ mmcv.dump(config_dict, file_format='json'))
+ session_text = yaml.dump(config_dict)
+ self.init_kwargs['session_text'] = session_text
+ self.writer = SummaryWriter(**self.init_kwargs)
+
+ def get_step(self, runner):
+ """Get the total training step/epoch."""
+ if self.get_mode(runner) == 'val' and self.by_epoch:
+ return self.get_epoch(runner)
+ else:
+ return self.get_iter(runner)
+
+ @master_only
+ def log(self, runner):
+ tags = self.get_loggable_tags(runner, add_mode=False)
+ if tags:
+ self.writer.add_scalars(
+ self.get_mode(runner), tags, self.get_step(runner))
+
+ @master_only
+ def after_run(self, runner):
+ if self.add_last_ckpt:
+ ckpt_path = osp.join(runner.work_dir, 'latest.pth')
+ if osp.islink(ckpt_path):
+ ckpt_path = osp.join(runner.work_dir, os.readlink(ckpt_path))
+
+ if osp.isfile(ckpt_path):
+ # runner.epoch += 1 has been done before `after_run`.
+ iteration = runner.epoch if self.by_epoch else runner.iter
+ return self.writer.add_snapshot_file(
+ tag=self.run_name,
+ snapshot_file_path=ckpt_path,
+ iteration=iteration)
+
+ # flush the buffer and send a task ending signal to Pavi
+ self.writer.close()
+
+ @master_only
+ def before_epoch(self, runner):
+ if runner.epoch == 0 and self.add_graph:
+ if is_module_wrapper(runner.model):
+ _model = runner.model.module
+ else:
+ _model = runner.model
+ device = next(_model.parameters()).device
+ data = next(iter(runner.data_loader))
+ image = data[self.img_key][0:1].to(device)
+ with torch.no_grad():
+ self.writer.add_graph(_model, image)
diff --git a/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/tensorboard.py b/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/tensorboard.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc9c727ff9776c5c8d41838f2f0676a4db56186b
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/tensorboard.py
@@ -0,0 +1,57 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+
+from custom_mmpkg.custom_mmcv.utils import TORCH_VERSION, digit_version
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class TensorboardLoggerHook(LoggerHook):
+
+ def __init__(self,
+ log_dir=None,
+ interval=10,
+ ignore_last=True,
+ reset_flag=False,
+ by_epoch=True):
+ super(TensorboardLoggerHook, self).__init__(interval, ignore_last,
+ reset_flag, by_epoch)
+ self.log_dir = log_dir
+
+ @master_only
+ def before_run(self, runner):
+ super(TensorboardLoggerHook, self).before_run(runner)
+ if (TORCH_VERSION == 'parrots'
+ or digit_version(TORCH_VERSION) < digit_version('1.1')):
+ try:
+ from tensorboardX import SummaryWriter
+ except ImportError:
+ raise ImportError('Please install tensorboardX to use '
+ 'TensorboardLoggerHook.')
+ else:
+ try:
+ from torch.utils.tensorboard import SummaryWriter
+ except ImportError:
+ raise ImportError(
+ 'Please run "pip install future tensorboard" to install '
+ 'the dependencies to use torch.utils.tensorboard '
+ '(applicable to PyTorch 1.1 or higher)')
+
+ if self.log_dir is None:
+ self.log_dir = osp.join(runner.work_dir, 'tf_logs')
+ self.writer = SummaryWriter(self.log_dir)
+
+ @master_only
+ def log(self, runner):
+ tags = self.get_loggable_tags(runner, allow_text=True)
+ for tag, val in tags.items():
+ if isinstance(val, str):
+ self.writer.add_text(tag, val, self.get_iter(runner))
+ else:
+ self.writer.add_scalar(tag, val, self.get_iter(runner))
+
+ @master_only
+ def after_run(self, runner):
+ self.writer.close()
diff --git a/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/text.py b/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/text.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea12c02a96d590493ae48055196bb28798bfefff
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/text.py
@@ -0,0 +1,256 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import datetime
+import os
+import os.path as osp
+from collections import OrderedDict
+
+import torch
+import torch.distributed as dist
+
+import custom_mmpkg.custom_mmcv as mmcv
+from custom_mmpkg.custom_mmcv.fileio.file_client import FileClient
+from custom_mmpkg.custom_mmcv.utils import is_tuple_of, scandir
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class TextLoggerHook(LoggerHook):
+ """Logger hook in text.
+
+ In this logger hook, the information will be printed on terminal and
+ saved in json file.
+
+ Args:
+ by_epoch (bool, optional): Whether EpochBasedRunner is used.
+ Default: True.
+ interval (int, optional): Logging interval (every k iterations).
+ Default: 10.
+ ignore_last (bool, optional): Ignore the log of last iterations in each
+ epoch if less than :attr:`interval`. Default: True.
+ reset_flag (bool, optional): Whether to clear the output buffer after
+ logging. Default: False.
+ interval_exp_name (int, optional): Logging interval for experiment
+ name. This feature is to help users conveniently get the experiment
+ information from screen or log file. Default: 1000.
+ out_dir (str, optional): Logs are saved in ``runner.work_dir`` default.
+ If ``out_dir`` is specified, logs will be copied to a new directory
+ which is the concatenation of ``out_dir`` and the last level
+ directory of ``runner.work_dir``. Default: None.
+ `New in version 1.3.16.`
+ out_suffix (str or tuple[str], optional): Those filenames ending with
+ ``out_suffix`` will be copied to ``out_dir``.
+ Default: ('.log.json', '.log', '.py').
+ `New in version 1.3.16.`
+ keep_local (bool, optional): Whether to keep local log when
+ :attr:`out_dir` is specified. If False, the local log will be
+ removed. Default: True.
+ `New in version 1.3.16.`
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
+ Default: None.
+ `New in version 1.3.16.`
+ """
+
+ def __init__(self,
+ by_epoch=True,
+ interval=10,
+ ignore_last=True,
+ reset_flag=False,
+ interval_exp_name=1000,
+ out_dir=None,
+ out_suffix=('.log.json', '.log', '.py'),
+ keep_local=True,
+ file_client_args=None):
+ super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag,
+ by_epoch)
+ self.by_epoch = by_epoch
+ self.time_sec_tot = 0
+ self.interval_exp_name = interval_exp_name
+
+ if out_dir is None and file_client_args is not None:
+ raise ValueError(
+ 'file_client_args should be "None" when `out_dir` is not'
+ 'specified.')
+ self.out_dir = out_dir
+
+ if not (out_dir is None or isinstance(out_dir, str)
+ or is_tuple_of(out_dir, str)):
+ raise TypeError('out_dir should be "None" or string or tuple of '
+ 'string, but got {out_dir}')
+ self.out_suffix = out_suffix
+
+ self.keep_local = keep_local
+ self.file_client_args = file_client_args
+ if self.out_dir is not None:
+ self.file_client = FileClient.infer_client(file_client_args,
+ self.out_dir)
+
+ def before_run(self, runner):
+ super(TextLoggerHook, self).before_run(runner)
+
+ if self.out_dir is not None:
+ self.file_client = FileClient.infer_client(self.file_client_args,
+ self.out_dir)
+ # The final `self.out_dir` is the concatenation of `self.out_dir`
+ # and the last level directory of `runner.work_dir`
+ basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+ self.out_dir = self.file_client.join_path(self.out_dir, basename)
+ runner.logger.info(
+ (f'Text logs will be saved to {self.out_dir} by '
+ f'{self.file_client.name} after the training process.'))
+
+ self.start_iter = runner.iter
+ self.json_log_path = osp.join(runner.work_dir,
+ f'{runner.timestamp}.log.json')
+ if runner.meta is not None:
+ self._dump_log(runner.meta, runner)
+
+ def _get_max_memory(self, runner):
+ device = getattr(runner.model, 'output_device', None)
+ mem = torch.cuda.max_memory_allocated(device=device)
+ mem_mb = torch.tensor([mem / (1024 * 1024)],
+ dtype=torch.int,
+ device=device)
+ if runner.world_size > 1:
+ dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
+ return mem_mb.item()
+
+ def _log_info(self, log_dict, runner):
+ # print exp name for users to distinguish experiments
+ # at every ``interval_exp_name`` iterations and the end of each epoch
+ if runner.meta is not None and 'exp_name' in runner.meta:
+ if (self.every_n_iters(runner, self.interval_exp_name)) or (
+ self.by_epoch and self.end_of_epoch(runner)):
+ exp_info = f'Exp name: {runner.meta["exp_name"]}'
+ runner.logger.info(exp_info)
+
+ if log_dict['mode'] == 'train':
+ if isinstance(log_dict['lr'], dict):
+ lr_str = []
+ for k, val in log_dict['lr'].items():
+ lr_str.append(f'lr_{k}: {val:.3e}')
+ lr_str = ' '.join(lr_str)
+ else:
+ lr_str = f'lr: {log_dict["lr"]:.3e}'
+
+ # by epoch: Epoch [4][100/1000]
+ # by iter: Iter [100/100000]
+ if self.by_epoch:
+ log_str = f'Epoch [{log_dict["epoch"]}]' \
+ f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t'
+ else:
+ log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}]\t'
+ log_str += f'{lr_str}, '
+
+ if 'time' in log_dict.keys():
+ self.time_sec_tot += (log_dict['time'] * self.interval)
+ time_sec_avg = self.time_sec_tot / (
+ runner.iter - self.start_iter + 1)
+ eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
+ eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
+ log_str += f'eta: {eta_str}, '
+ log_str += f'time: {log_dict["time"]:.3f}, ' \
+ f'data_time: {log_dict["data_time"]:.3f}, '
+ # statistic memory
+ if torch.cuda.is_available():
+ log_str += f'memory: {log_dict["memory"]}, '
+ else:
+ # val/test time
+ # here 1000 is the length of the val dataloader
+ # by epoch: Epoch[val] [4][1000]
+ # by iter: Iter[val] [1000]
+ if self.by_epoch:
+ log_str = f'Epoch({log_dict["mode"]}) ' \
+ f'[{log_dict["epoch"]}][{log_dict["iter"]}]\t'
+ else:
+ log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t'
+
+ log_items = []
+ for name, val in log_dict.items():
+ # TODO: resolve this hack
+ # these items have been in log_str
+ if name in [
+ 'mode', 'Epoch', 'iter', 'lr', 'time', 'data_time',
+ 'memory', 'epoch'
+ ]:
+ continue
+ if isinstance(val, float):
+ val = f'{val:.4f}'
+ log_items.append(f'{name}: {val}')
+ log_str += ', '.join(log_items)
+
+ runner.logger.info(log_str)
+
+ def _dump_log(self, log_dict, runner):
+ # dump log in json format
+ json_log = OrderedDict()
+ for k, v in log_dict.items():
+ json_log[k] = self._round_float(v)
+ # only append log at last line
+ if runner.rank == 0:
+ with open(self.json_log_path, 'a+') as f:
+ mmcv.dump(json_log, f, file_format='json')
+ f.write('\n')
+
+ def _round_float(self, items):
+ if isinstance(items, list):
+ return [self._round_float(item) for item in items]
+ elif isinstance(items, float):
+ return round(items, 5)
+ else:
+ return items
+
+ def log(self, runner):
+ if 'eval_iter_num' in runner.log_buffer.output:
+ # this doesn't modify runner.iter and is regardless of by_epoch
+ cur_iter = runner.log_buffer.output.pop('eval_iter_num')
+ else:
+ cur_iter = self.get_iter(runner, inner_iter=True)
+
+ log_dict = OrderedDict(
+ mode=self.get_mode(runner),
+ epoch=self.get_epoch(runner),
+ iter=cur_iter)
+
+ # only record lr of the first param group
+ cur_lr = runner.current_lr()
+ if isinstance(cur_lr, list):
+ log_dict['lr'] = cur_lr[0]
+ else:
+ assert isinstance(cur_lr, dict)
+ log_dict['lr'] = {}
+ for k, lr_ in cur_lr.items():
+ assert isinstance(lr_, list)
+ log_dict['lr'].update({k: lr_[0]})
+
+ if 'time' in runner.log_buffer.output:
+ # statistic memory
+ if torch.cuda.is_available():
+ log_dict['memory'] = self._get_max_memory(runner)
+
+ log_dict = dict(log_dict, **runner.log_buffer.output)
+
+ self._log_info(log_dict, runner)
+ self._dump_log(log_dict, runner)
+ return log_dict
+
+ def after_run(self, runner):
+ # copy or upload logs to self.out_dir
+ if self.out_dir is not None:
+ for filename in scandir(runner.work_dir, self.out_suffix, True):
+ local_filepath = osp.join(runner.work_dir, filename)
+ out_filepath = self.file_client.join_path(
+ self.out_dir, filename)
+ with open(local_filepath, 'r') as f:
+ self.file_client.put_text(f.read(), out_filepath)
+
+ runner.logger.info(
+ (f'The file {local_filepath} has been uploaded to '
+ f'{out_filepath}.'))
+
+ if not self.keep_local:
+ os.remove(local_filepath)
+ runner.logger.info(
+ (f'{local_filepath} was removed due to the '
+ '`self.keep_local=False`'))
diff --git a/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/wandb.py b/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/wandb.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f6808462eb79ab2b04806a5d9f0d3dd079b5ea9
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/hooks/logger/wandb.py
@@ -0,0 +1,56 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class WandbLoggerHook(LoggerHook):
+
+ def __init__(self,
+ init_kwargs=None,
+ interval=10,
+ ignore_last=True,
+ reset_flag=False,
+ commit=True,
+ by_epoch=True,
+ with_step=True):
+ super(WandbLoggerHook, self).__init__(interval, ignore_last,
+ reset_flag, by_epoch)
+ self.import_wandb()
+ self.init_kwargs = init_kwargs
+ self.commit = commit
+ self.with_step = with_step
+
+ def import_wandb(self):
+ try:
+ import wandb
+ except ImportError:
+ raise ImportError(
+ 'Please run "pip install wandb" to install wandb')
+ self.wandb = wandb
+
+ @master_only
+ def before_run(self, runner):
+ super(WandbLoggerHook, self).before_run(runner)
+ if self.wandb is None:
+ self.import_wandb()
+ if self.init_kwargs:
+ self.wandb.init(**self.init_kwargs)
+ else:
+ self.wandb.init()
+
+ @master_only
+ def log(self, runner):
+ tags = self.get_loggable_tags(runner)
+ if tags:
+ if self.with_step:
+ self.wandb.log(
+ tags, step=self.get_iter(runner), commit=self.commit)
+ else:
+ tags['global_step'] = self.get_iter(runner)
+ self.wandb.log(tags, commit=self.commit)
+
+ @master_only
+ def after_run(self, runner):
+ self.wandb.join()
diff --git a/src/custom_mmpkg/custom_mmcv/runner/hooks/lr_updater.py b/src/custom_mmpkg/custom_mmcv/runner/hooks/lr_updater.py
new file mode 100644
index 0000000000000000000000000000000000000000..f375932319cdbce2d50a7fc60b68ea750a60bb70
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/hooks/lr_updater.py
@@ -0,0 +1,670 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numbers
+from math import cos, pi
+
+import custom_mmpkg.custom_mmcv as mmcv
+from .hook import HOOKS, Hook
+
+
+class LrUpdaterHook(Hook):
+ """LR Scheduler in MMCV.
+
+ Args:
+ by_epoch (bool): LR changes epoch by epoch
+ warmup (string): Type of warmup used. It can be None(use no warmup),
+ 'constant', 'linear' or 'exp'
+ warmup_iters (int): The number of iterations or epochs that warmup
+ lasts
+ warmup_ratio (float): LR used at the beginning of warmup equals to
+ warmup_ratio * initial_lr
+ warmup_by_epoch (bool): When warmup_by_epoch == True, warmup_iters
+ means the number of epochs that warmup lasts, otherwise means the
+ number of iteration that warmup lasts
+ """
+
+ def __init__(self,
+ by_epoch=True,
+ warmup=None,
+ warmup_iters=0,
+ warmup_ratio=0.1,
+ warmup_by_epoch=False):
+ # validate the "warmup" argument
+ if warmup is not None:
+ if warmup not in ['constant', 'linear', 'exp']:
+ raise ValueError(
+ f'"{warmup}" is not a supported type for warming up, valid'
+ ' types are "constant" and "linear"')
+ if warmup is not None:
+ assert warmup_iters > 0, \
+ '"warmup_iters" must be a positive integer'
+ assert 0 < warmup_ratio <= 1.0, \
+ '"warmup_ratio" must be in range (0,1]'
+
+ self.by_epoch = by_epoch
+ self.warmup = warmup
+ self.warmup_iters = warmup_iters
+ self.warmup_ratio = warmup_ratio
+ self.warmup_by_epoch = warmup_by_epoch
+
+ if self.warmup_by_epoch:
+ self.warmup_epochs = self.warmup_iters
+ self.warmup_iters = None
+ else:
+ self.warmup_epochs = None
+
+ self.base_lr = [] # initial lr for all param groups
+ self.regular_lr = [] # expected lr if no warming up is performed
+
+ def _set_lr(self, runner, lr_groups):
+ if isinstance(runner.optimizer, dict):
+ for k, optim in runner.optimizer.items():
+ for param_group, lr in zip(optim.param_groups, lr_groups[k]):
+ param_group['lr'] = lr
+ else:
+ for param_group, lr in zip(runner.optimizer.param_groups,
+ lr_groups):
+ param_group['lr'] = lr
+
+ def get_lr(self, runner, base_lr):
+ raise NotImplementedError
+
+ def get_regular_lr(self, runner):
+ if isinstance(runner.optimizer, dict):
+ lr_groups = {}
+ for k in runner.optimizer.keys():
+ _lr_group = [
+ self.get_lr(runner, _base_lr)
+ for _base_lr in self.base_lr[k]
+ ]
+ lr_groups.update({k: _lr_group})
+
+ return lr_groups
+ else:
+ return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr]
+
+ def get_warmup_lr(self, cur_iters):
+
+ def _get_warmup_lr(cur_iters, regular_lr):
+ if self.warmup == 'constant':
+ warmup_lr = [_lr * self.warmup_ratio for _lr in regular_lr]
+ elif self.warmup == 'linear':
+ k = (1 - cur_iters / self.warmup_iters) * (1 -
+ self.warmup_ratio)
+ warmup_lr = [_lr * (1 - k) for _lr in regular_lr]
+ elif self.warmup == 'exp':
+ k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
+ warmup_lr = [_lr * k for _lr in regular_lr]
+ return warmup_lr
+
+ if isinstance(self.regular_lr, dict):
+ lr_groups = {}
+ for key, regular_lr in self.regular_lr.items():
+ lr_groups[key] = _get_warmup_lr(cur_iters, regular_lr)
+ return lr_groups
+ else:
+ return _get_warmup_lr(cur_iters, self.regular_lr)
+
+ def before_run(self, runner):
+ # NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,
+ # it will be set according to the optimizer params
+ if isinstance(runner.optimizer, dict):
+ self.base_lr = {}
+ for k, optim in runner.optimizer.items():
+ for group in optim.param_groups:
+ group.setdefault('initial_lr', group['lr'])
+ _base_lr = [
+ group['initial_lr'] for group in optim.param_groups
+ ]
+ self.base_lr.update({k: _base_lr})
+ else:
+ for group in runner.optimizer.param_groups:
+ group.setdefault('initial_lr', group['lr'])
+ self.base_lr = [
+ group['initial_lr'] for group in runner.optimizer.param_groups
+ ]
+
+ def before_train_epoch(self, runner):
+ if self.warmup_iters is None:
+ epoch_len = len(runner.data_loader)
+ self.warmup_iters = self.warmup_epochs * epoch_len
+
+ if not self.by_epoch:
+ return
+
+ self.regular_lr = self.get_regular_lr(runner)
+ self._set_lr(runner, self.regular_lr)
+
+ def before_train_iter(self, runner):
+ cur_iter = runner.iter
+ if not self.by_epoch:
+ self.regular_lr = self.get_regular_lr(runner)
+ if self.warmup is None or cur_iter >= self.warmup_iters:
+ self._set_lr(runner, self.regular_lr)
+ else:
+ warmup_lr = self.get_warmup_lr(cur_iter)
+ self._set_lr(runner, warmup_lr)
+ elif self.by_epoch:
+ if self.warmup is None or cur_iter > self.warmup_iters:
+ return
+ elif cur_iter == self.warmup_iters:
+ self._set_lr(runner, self.regular_lr)
+ else:
+ warmup_lr = self.get_warmup_lr(cur_iter)
+ self._set_lr(runner, warmup_lr)
+
+
+@HOOKS.register_module()
+class FixedLrUpdaterHook(LrUpdaterHook):
+
+ def __init__(self, **kwargs):
+ super(FixedLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, runner, base_lr):
+ return base_lr
+
+
+@HOOKS.register_module()
+class StepLrUpdaterHook(LrUpdaterHook):
+ """Step LR scheduler with min_lr clipping.
+
+ Args:
+ step (int | list[int]): Step to decay the LR. If an int value is given,
+ regard it as the decay interval. If a list is given, decay LR at
+ these steps.
+ gamma (float, optional): Decay LR ratio. Default: 0.1.
+ min_lr (float, optional): Minimum LR value to keep. If LR after decay
+ is lower than `min_lr`, it will be clipped to this value. If None
+ is given, we don't perform lr clipping. Default: None.
+ """
+
+ def __init__(self, step, gamma=0.1, min_lr=None, **kwargs):
+ if isinstance(step, list):
+ assert mmcv.is_list_of(step, int)
+ assert all([s > 0 for s in step])
+ elif isinstance(step, int):
+ assert step > 0
+ else:
+ raise TypeError('"step" must be a list or integer')
+ self.step = step
+ self.gamma = gamma
+ self.min_lr = min_lr
+ super(StepLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, runner, base_lr):
+ progress = runner.epoch if self.by_epoch else runner.iter
+
+ # calculate exponential term
+ if isinstance(self.step, int):
+ exp = progress // self.step
+ else:
+ exp = len(self.step)
+ for i, s in enumerate(self.step):
+ if progress < s:
+ exp = i
+ break
+
+ lr = base_lr * (self.gamma**exp)
+ if self.min_lr is not None:
+ # clip to a minimum value
+ lr = max(lr, self.min_lr)
+ return lr
+
+
+@HOOKS.register_module()
+class ExpLrUpdaterHook(LrUpdaterHook):
+
+ def __init__(self, gamma, **kwargs):
+ self.gamma = gamma
+ super(ExpLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, runner, base_lr):
+ progress = runner.epoch if self.by_epoch else runner.iter
+ return base_lr * self.gamma**progress
+
+
+@HOOKS.register_module()
+class PolyLrUpdaterHook(LrUpdaterHook):
+
+ def __init__(self, power=1., min_lr=0., **kwargs):
+ self.power = power
+ self.min_lr = min_lr
+ super(PolyLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, runner, base_lr):
+ if self.by_epoch:
+ progress = runner.epoch
+ max_progress = runner.max_epochs
+ else:
+ progress = runner.iter
+ max_progress = runner.max_iters
+ coeff = (1 - progress / max_progress)**self.power
+ return (base_lr - self.min_lr) * coeff + self.min_lr
+
+
+@HOOKS.register_module()
+class InvLrUpdaterHook(LrUpdaterHook):
+
+ def __init__(self, gamma, power=1., **kwargs):
+ self.gamma = gamma
+ self.power = power
+ super(InvLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, runner, base_lr):
+ progress = runner.epoch if self.by_epoch else runner.iter
+ return base_lr * (1 + self.gamma * progress)**(-self.power)
+
+
+@HOOKS.register_module()
+class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
+
+ def __init__(self, min_lr=None, min_lr_ratio=None, **kwargs):
+ assert (min_lr is None) ^ (min_lr_ratio is None)
+ self.min_lr = min_lr
+ self.min_lr_ratio = min_lr_ratio
+ super(CosineAnnealingLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, runner, base_lr):
+ if self.by_epoch:
+ progress = runner.epoch
+ max_progress = runner.max_epochs
+ else:
+ progress = runner.iter
+ max_progress = runner.max_iters
+
+ if self.min_lr_ratio is not None:
+ target_lr = base_lr * self.min_lr_ratio
+ else:
+ target_lr = self.min_lr
+ return annealing_cos(base_lr, target_lr, progress / max_progress)
+
+
+@HOOKS.register_module()
+class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
+ """Flat + Cosine lr schedule.
+
+ Modified from https://github.com/fastai/fastai/blob/master/fastai/callback/schedule.py#L128 # noqa: E501
+
+ Args:
+ start_percent (float): When to start annealing the learning rate
+ after the percentage of the total training steps.
+ The value should be in range [0, 1).
+ Default: 0.75
+ min_lr (float, optional): The minimum lr. Default: None.
+ min_lr_ratio (float, optional): The ratio of minimum lr to the base lr.
+ Either `min_lr` or `min_lr_ratio` should be specified.
+ Default: None.
+ """
+
+ def __init__(self,
+ start_percent=0.75,
+ min_lr=None,
+ min_lr_ratio=None,
+ **kwargs):
+ assert (min_lr is None) ^ (min_lr_ratio is None)
+ if start_percent < 0 or start_percent > 1 or not isinstance(
+ start_percent, float):
+ raise ValueError(
+ 'expected float between 0 and 1 start_percent, but '
+ f'got {start_percent}')
+ self.start_percent = start_percent
+ self.min_lr = min_lr
+ self.min_lr_ratio = min_lr_ratio
+ super(FlatCosineAnnealingLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, runner, base_lr):
+ if self.by_epoch:
+ start = round(runner.max_epochs * self.start_percent)
+ progress = runner.epoch - start
+ max_progress = runner.max_epochs - start
+ else:
+ start = round(runner.max_iters * self.start_percent)
+ progress = runner.iter - start
+ max_progress = runner.max_iters - start
+
+ if self.min_lr_ratio is not None:
+ target_lr = base_lr * self.min_lr_ratio
+ else:
+ target_lr = self.min_lr
+
+ if progress < 0:
+ return base_lr
+ else:
+ return annealing_cos(base_lr, target_lr, progress / max_progress)
+
+
+@HOOKS.register_module()
+class CosineRestartLrUpdaterHook(LrUpdaterHook):
+ """Cosine annealing with restarts learning rate scheme.
+
+ Args:
+ periods (list[int]): Periods for each cosine anneling cycle.
+ restart_weights (list[float], optional): Restart weights at each
+ restart iteration. Default: [1].
+ min_lr (float, optional): The minimum lr. Default: None.
+ min_lr_ratio (float, optional): The ratio of minimum lr to the base lr.
+ Either `min_lr` or `min_lr_ratio` should be specified.
+ Default: None.
+ """
+
+ def __init__(self,
+ periods,
+ restart_weights=[1],
+ min_lr=None,
+ min_lr_ratio=None,
+ **kwargs):
+ assert (min_lr is None) ^ (min_lr_ratio is None)
+ self.periods = periods
+ self.min_lr = min_lr
+ self.min_lr_ratio = min_lr_ratio
+ self.restart_weights = restart_weights
+ assert (len(self.periods) == len(self.restart_weights)
+ ), 'periods and restart_weights should have the same length.'
+ super(CosineRestartLrUpdaterHook, self).__init__(**kwargs)
+
+ self.cumulative_periods = [
+ sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
+ ]
+
+ def get_lr(self, runner, base_lr):
+ if self.by_epoch:
+ progress = runner.epoch
+ else:
+ progress = runner.iter
+
+ if self.min_lr_ratio is not None:
+ target_lr = base_lr * self.min_lr_ratio
+ else:
+ target_lr = self.min_lr
+
+ idx = get_position_from_periods(progress, self.cumulative_periods)
+ current_weight = self.restart_weights[idx]
+ nearest_restart = 0 if idx == 0 else self.cumulative_periods[idx - 1]
+ current_periods = self.periods[idx]
+
+ alpha = min((progress - nearest_restart) / current_periods, 1)
+ return annealing_cos(base_lr, target_lr, alpha, current_weight)
+
+
+def get_position_from_periods(iteration, cumulative_periods):
+ """Get the position from a period list.
+
+ It will return the index of the right-closest number in the period list.
+ For example, the cumulative_periods = [100, 200, 300, 400],
+ if iteration == 50, return 0;
+ if iteration == 210, return 2;
+ if iteration == 300, return 3.
+
+ Args:
+ iteration (int): Current iteration.
+ cumulative_periods (list[int]): Cumulative period list.
+
+ Returns:
+ int: The position of the right-closest number in the period list.
+ """
+ for i, period in enumerate(cumulative_periods):
+ if iteration < period:
+ return i
+ raise ValueError(f'Current iteration {iteration} exceeds '
+ f'cumulative_periods {cumulative_periods}')
+
+
+@HOOKS.register_module()
+class CyclicLrUpdaterHook(LrUpdaterHook):
+ """Cyclic LR Scheduler.
+
+ Implement the cyclical learning rate policy (CLR) described in
+ https://arxiv.org/pdf/1506.01186.pdf
+
+ Different from the original paper, we use cosine annealing rather than
+ triangular policy inside a cycle. This improves the performance in the
+ 3D detection area.
+
+ Args:
+ by_epoch (bool): Whether to update LR by epoch.
+ target_ratio (tuple[float]): Relative ratio of the highest LR and the
+ lowest LR to the initial LR.
+ cyclic_times (int): Number of cycles during training
+ step_ratio_up (float): The ratio of the increasing process of LR in
+ the total cycle.
+ anneal_strategy (str): {'cos', 'linear'}
+ Specifies the annealing strategy: 'cos' for cosine annealing,
+ 'linear' for linear annealing. Default: 'cos'.
+ """
+
+ def __init__(self,
+ by_epoch=False,
+ target_ratio=(10, 1e-4),
+ cyclic_times=1,
+ step_ratio_up=0.4,
+ anneal_strategy='cos',
+ **kwargs):
+ if isinstance(target_ratio, float):
+ target_ratio = (target_ratio, target_ratio / 1e5)
+ elif isinstance(target_ratio, tuple):
+ target_ratio = (target_ratio[0], target_ratio[0] / 1e5) \
+ if len(target_ratio) == 1 else target_ratio
+ else:
+ raise ValueError('target_ratio should be either float '
+ f'or tuple, got {type(target_ratio)}')
+
+ assert len(target_ratio) == 2, \
+ '"target_ratio" must be list or tuple of two floats'
+ assert 0 <= step_ratio_up < 1.0, \
+ '"step_ratio_up" must be in range [0,1)'
+
+ self.target_ratio = target_ratio
+ self.cyclic_times = cyclic_times
+ self.step_ratio_up = step_ratio_up
+ self.lr_phases = [] # init lr_phases
+ # validate anneal_strategy
+ if anneal_strategy not in ['cos', 'linear']:
+ raise ValueError('anneal_strategy must be one of "cos" or '
+ f'"linear", instead got {anneal_strategy}')
+ elif anneal_strategy == 'cos':
+ self.anneal_func = annealing_cos
+ elif anneal_strategy == 'linear':
+ self.anneal_func = annealing_linear
+
+ assert not by_epoch, \
+ 'currently only support "by_epoch" = False'
+ super(CyclicLrUpdaterHook, self).__init__(by_epoch, **kwargs)
+
+ def before_run(self, runner):
+ super(CyclicLrUpdaterHook, self).before_run(runner)
+ # initiate lr_phases
+ # total lr_phases are separated as up and down
+ max_iter_per_phase = runner.max_iters // self.cyclic_times
+ iter_up_phase = int(self.step_ratio_up * max_iter_per_phase)
+ self.lr_phases.append(
+ [0, iter_up_phase, max_iter_per_phase, 1, self.target_ratio[0]])
+ self.lr_phases.append([
+ iter_up_phase, max_iter_per_phase, max_iter_per_phase,
+ self.target_ratio[0], self.target_ratio[1]
+ ])
+
+ def get_lr(self, runner, base_lr):
+ curr_iter = runner.iter
+ for (start_iter, end_iter, max_iter_per_phase, start_ratio,
+ end_ratio) in self.lr_phases:
+ curr_iter %= max_iter_per_phase
+ if start_iter <= curr_iter < end_iter:
+ progress = curr_iter - start_iter
+ return self.anneal_func(base_lr * start_ratio,
+ base_lr * end_ratio,
+ progress / (end_iter - start_iter))
+
+
+@HOOKS.register_module()
+class OneCycleLrUpdaterHook(LrUpdaterHook):
+ """One Cycle LR Scheduler.
+
+ The 1cycle learning rate policy changes the learning rate after every
+ batch. The one cycle learning rate policy is described in
+ https://arxiv.org/pdf/1708.07120.pdf
+
+ Args:
+ max_lr (float or list): Upper learning rate boundaries in the cycle
+ for each parameter group.
+ total_steps (int, optional): The total number of steps in the cycle.
+ Note that if a value is not provided here, it will be the max_iter
+ of runner. Default: None.
+ pct_start (float): The percentage of the cycle (in number of steps)
+ spent increasing the learning rate.
+ Default: 0.3
+ anneal_strategy (str): {'cos', 'linear'}
+ Specifies the annealing strategy: 'cos' for cosine annealing,
+ 'linear' for linear annealing.
+ Default: 'cos'
+ div_factor (float): Determines the initial learning rate via
+ initial_lr = max_lr/div_factor
+ Default: 25
+ final_div_factor (float): Determines the minimum learning rate via
+ min_lr = initial_lr/final_div_factor
+ Default: 1e4
+ three_phase (bool): If three_phase is True, use a third phase of the
+ schedule to annihilate the learning rate according to
+ final_div_factor instead of modifying the second phase (the first
+ two phases will be symmetrical about the step indicated by
+ pct_start).
+ Default: False
+ """
+
+ def __init__(self,
+ max_lr,
+ total_steps=None,
+ pct_start=0.3,
+ anneal_strategy='cos',
+ div_factor=25,
+ final_div_factor=1e4,
+ three_phase=False,
+ **kwargs):
+ # validate by_epoch, currently only support by_epoch = False
+ if 'by_epoch' not in kwargs:
+ kwargs['by_epoch'] = False
+ else:
+ assert not kwargs['by_epoch'], \
+ 'currently only support "by_epoch" = False'
+ if not isinstance(max_lr, (numbers.Number, list, dict)):
+ raise ValueError('the type of max_lr must be the one of list or '
+ f'dict, but got {type(max_lr)}')
+ self._max_lr = max_lr
+ if total_steps is not None:
+ if not isinstance(total_steps, int):
+ raise ValueError('the type of total_steps must be int, but'
+ f'got {type(total_steps)}')
+ self.total_steps = total_steps
+ # validate pct_start
+ if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
+ raise ValueError('expected float between 0 and 1 pct_start, but '
+ f'got {pct_start}')
+ self.pct_start = pct_start
+ # validate anneal_strategy
+ if anneal_strategy not in ['cos', 'linear']:
+ raise ValueError('anneal_strategy must be one of "cos" or '
+ f'"linear", instead got {anneal_strategy}')
+ elif anneal_strategy == 'cos':
+ self.anneal_func = annealing_cos
+ elif anneal_strategy == 'linear':
+ self.anneal_func = annealing_linear
+ self.div_factor = div_factor
+ self.final_div_factor = final_div_factor
+ self.three_phase = three_phase
+ self.lr_phases = [] # init lr_phases
+ super(OneCycleLrUpdaterHook, self).__init__(**kwargs)
+
+ def before_run(self, runner):
+ if hasattr(self, 'total_steps'):
+ total_steps = self.total_steps
+ else:
+ total_steps = runner.max_iters
+ if total_steps < runner.max_iters:
+ raise ValueError(
+ 'The total steps must be greater than or equal to max '
+ f'iterations {runner.max_iters} of runner, but total steps '
+ f'is {total_steps}.')
+
+ if isinstance(runner.optimizer, dict):
+ self.base_lr = {}
+ for k, optim in runner.optimizer.items():
+ _max_lr = format_param(k, optim, self._max_lr)
+ self.base_lr[k] = [lr / self.div_factor for lr in _max_lr]
+ for group, lr in zip(optim.param_groups, self.base_lr[k]):
+ group.setdefault('initial_lr', lr)
+ else:
+ k = type(runner.optimizer).__name__
+ _max_lr = format_param(k, runner.optimizer, self._max_lr)
+ self.base_lr = [lr / self.div_factor for lr in _max_lr]
+ for group, lr in zip(runner.optimizer.param_groups, self.base_lr):
+ group.setdefault('initial_lr', lr)
+
+ if self.three_phase:
+ self.lr_phases.append(
+ [float(self.pct_start * total_steps) - 1, 1, self.div_factor])
+ self.lr_phases.append([
+ float(2 * self.pct_start * total_steps) - 2, self.div_factor, 1
+ ])
+ self.lr_phases.append(
+ [total_steps - 1, 1, 1 / self.final_div_factor])
+ else:
+ self.lr_phases.append(
+ [float(self.pct_start * total_steps) - 1, 1, self.div_factor])
+ self.lr_phases.append(
+ [total_steps - 1, self.div_factor, 1 / self.final_div_factor])
+
+ def get_lr(self, runner, base_lr):
+ curr_iter = runner.iter
+ start_iter = 0
+ for i, (end_iter, start_lr, end_lr) in enumerate(self.lr_phases):
+ if curr_iter <= end_iter:
+ pct = (curr_iter - start_iter) / (end_iter - start_iter)
+ lr = self.anneal_func(base_lr * start_lr, base_lr * end_lr,
+ pct)
+ break
+ start_iter = end_iter
+ return lr
+
+
+def annealing_cos(start, end, factor, weight=1):
+ """Calculate annealing cos learning rate.
+
+ Cosine anneal from `weight * start + (1 - weight) * end` to `end` as
+ percentage goes from 0.0 to 1.0.
+
+ Args:
+ start (float): The starting learning rate of the cosine annealing.
+ end (float): The ending learing rate of the cosine annealing.
+ factor (float): The coefficient of `pi` when calculating the current
+ percentage. Range from 0.0 to 1.0.
+ weight (float, optional): The combination factor of `start` and `end`
+ when calculating the actual starting learning rate. Default to 1.
+ """
+ cos_out = cos(pi * factor) + 1
+ return end + 0.5 * weight * (start - end) * cos_out
+
+
+def annealing_linear(start, end, factor):
+ """Calculate annealing linear learning rate.
+
+ Linear anneal from `start` to `end` as percentage goes from 0.0 to 1.0.
+
+ Args:
+ start (float): The starting learning rate of the linear annealing.
+ end (float): The ending learing rate of the linear annealing.
+ factor (float): The coefficient of `pi` when calculating the current
+ percentage. Range from 0.0 to 1.0.
+ """
+ return start + (end - start) * factor
+
+
+def format_param(name, optim, param):
+ if isinstance(param, numbers.Number):
+ return [param] * len(optim.param_groups)
+ elif isinstance(param, (list, tuple)): # multi param groups
+ if len(param) != len(optim.param_groups):
+ raise ValueError(f'expected {len(optim.param_groups)} '
+ f'values for {name}, got {len(param)}')
+ return param
+ else: # multi optimizers
+ if name not in param:
+ raise KeyError(f'{name} is not found in {param.keys()}')
+ return param[name]
diff --git a/src/custom_mmpkg/custom_mmcv/runner/hooks/memory.py b/src/custom_mmpkg/custom_mmcv/runner/hooks/memory.py
new file mode 100644
index 0000000000000000000000000000000000000000..70cf9a838fb314e3bd3c07aadbc00921a81e83ed
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/hooks/memory.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class EmptyCacheHook(Hook):
+
+ def __init__(self, before_epoch=False, after_epoch=True, after_iter=False):
+ self._before_epoch = before_epoch
+ self._after_epoch = after_epoch
+ self._after_iter = after_iter
+
+ def after_iter(self, runner):
+ if self._after_iter:
+ torch.cuda.empty_cache()
+
+ def before_epoch(self, runner):
+ if self._before_epoch:
+ torch.cuda.empty_cache()
+
+ def after_epoch(self, runner):
+ if self._after_epoch:
+ torch.cuda.empty_cache()
diff --git a/src/custom_mmpkg/custom_mmcv/runner/hooks/momentum_updater.py b/src/custom_mmpkg/custom_mmcv/runner/hooks/momentum_updater.py
new file mode 100644
index 0000000000000000000000000000000000000000..29b6c7c531a24603cbfee463f23e0c310cbfff41
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/hooks/momentum_updater.py
@@ -0,0 +1,493 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import custom_mmpkg.custom_mmcv as mmcv
+from .hook import HOOKS, Hook
+from .lr_updater import annealing_cos, annealing_linear, format_param
+
+
+class MomentumUpdaterHook(Hook):
+
+ def __init__(self,
+ by_epoch=True,
+ warmup=None,
+ warmup_iters=0,
+ warmup_ratio=0.9):
+ # validate the "warmup" argument
+ if warmup is not None:
+ if warmup not in ['constant', 'linear', 'exp']:
+ raise ValueError(
+ f'"{warmup}" is not a supported type for warming up, valid'
+ ' types are "constant" and "linear"')
+ if warmup is not None:
+ assert warmup_iters > 0, \
+ '"warmup_iters" must be a positive integer'
+ assert 0 < warmup_ratio <= 1.0, \
+ '"warmup_momentum" must be in range (0,1]'
+
+ self.by_epoch = by_epoch
+ self.warmup = warmup
+ self.warmup_iters = warmup_iters
+ self.warmup_ratio = warmup_ratio
+
+ self.base_momentum = [] # initial momentum for all param groups
+ self.regular_momentum = [
+ ] # expected momentum if no warming up is performed
+
+ def _set_momentum(self, runner, momentum_groups):
+ if isinstance(runner.optimizer, dict):
+ for k, optim in runner.optimizer.items():
+ for param_group, mom in zip(optim.param_groups,
+ momentum_groups[k]):
+ if 'momentum' in param_group.keys():
+ param_group['momentum'] = mom
+ elif 'betas' in param_group.keys():
+ param_group['betas'] = (mom, param_group['betas'][1])
+ else:
+ for param_group, mom in zip(runner.optimizer.param_groups,
+ momentum_groups):
+ if 'momentum' in param_group.keys():
+ param_group['momentum'] = mom
+ elif 'betas' in param_group.keys():
+ param_group['betas'] = (mom, param_group['betas'][1])
+
+ def get_momentum(self, runner, base_momentum):
+ raise NotImplementedError
+
+ def get_regular_momentum(self, runner):
+ if isinstance(runner.optimizer, dict):
+ momentum_groups = {}
+ for k in runner.optimizer.keys():
+ _momentum_group = [
+ self.get_momentum(runner, _base_momentum)
+ for _base_momentum in self.base_momentum[k]
+ ]
+ momentum_groups.update({k: _momentum_group})
+ return momentum_groups
+ else:
+ return [
+ self.get_momentum(runner, _base_momentum)
+ for _base_momentum in self.base_momentum
+ ]
+
+ def get_warmup_momentum(self, cur_iters):
+
+ def _get_warmup_momentum(cur_iters, regular_momentum):
+ if self.warmup == 'constant':
+ warmup_momentum = [
+ _momentum / self.warmup_ratio
+ for _momentum in self.regular_momentum
+ ]
+ elif self.warmup == 'linear':
+ k = (1 - cur_iters / self.warmup_iters) * (1 -
+ self.warmup_ratio)
+ warmup_momentum = [
+ _momentum / (1 - k) for _momentum in self.regular_mom
+ ]
+ elif self.warmup == 'exp':
+ k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
+ warmup_momentum = [
+ _momentum / k for _momentum in self.regular_mom
+ ]
+ return warmup_momentum
+
+ if isinstance(self.regular_momentum, dict):
+ momentum_groups = {}
+ for key, regular_momentum in self.regular_momentum.items():
+ momentum_groups[key] = _get_warmup_momentum(
+ cur_iters, regular_momentum)
+ return momentum_groups
+ else:
+ return _get_warmup_momentum(cur_iters, self.regular_momentum)
+
+ def before_run(self, runner):
+ # NOTE: when resuming from a checkpoint,
+ # if 'initial_momentum' is not saved,
+ # it will be set according to the optimizer params
+ if isinstance(runner.optimizer, dict):
+ self.base_momentum = {}
+ for k, optim in runner.optimizer.items():
+ for group in optim.param_groups:
+ if 'momentum' in group.keys():
+ group.setdefault('initial_momentum', group['momentum'])
+ else:
+ group.setdefault('initial_momentum', group['betas'][0])
+ _base_momentum = [
+ group['initial_momentum'] for group in optim.param_groups
+ ]
+ self.base_momentum.update({k: _base_momentum})
+ else:
+ for group in runner.optimizer.param_groups:
+ if 'momentum' in group.keys():
+ group.setdefault('initial_momentum', group['momentum'])
+ else:
+ group.setdefault('initial_momentum', group['betas'][0])
+ self.base_momentum = [
+ group['initial_momentum']
+ for group in runner.optimizer.param_groups
+ ]
+
+ def before_train_epoch(self, runner):
+ if not self.by_epoch:
+ return
+ self.regular_mom = self.get_regular_momentum(runner)
+ self._set_momentum(runner, self.regular_mom)
+
+ def before_train_iter(self, runner):
+ cur_iter = runner.iter
+ if not self.by_epoch:
+ self.regular_mom = self.get_regular_momentum(runner)
+ if self.warmup is None or cur_iter >= self.warmup_iters:
+ self._set_momentum(runner, self.regular_mom)
+ else:
+ warmup_momentum = self.get_warmup_momentum(cur_iter)
+ self._set_momentum(runner, warmup_momentum)
+ elif self.by_epoch:
+ if self.warmup is None or cur_iter > self.warmup_iters:
+ return
+ elif cur_iter == self.warmup_iters:
+ self._set_momentum(runner, self.regular_mom)
+ else:
+ warmup_momentum = self.get_warmup_momentum(cur_iter)
+ self._set_momentum(runner, warmup_momentum)
+
+
+@HOOKS.register_module()
+class StepMomentumUpdaterHook(MomentumUpdaterHook):
+ """Step momentum scheduler with min value clipping.
+
+ Args:
+ step (int | list[int]): Step to decay the momentum. If an int value is
+ given, regard it as the decay interval. If a list is given, decay
+ momentum at these steps.
+ gamma (float, optional): Decay momentum ratio. Default: 0.5.
+ min_momentum (float, optional): Minimum momentum value to keep. If
+ momentum after decay is lower than this value, it will be clipped
+ accordingly. If None is given, we don't perform lr clipping.
+ Default: None.
+ """
+
+ def __init__(self, step, gamma=0.5, min_momentum=None, **kwargs):
+ if isinstance(step, list):
+ assert mmcv.is_list_of(step, int)
+ assert all([s > 0 for s in step])
+ elif isinstance(step, int):
+ assert step > 0
+ else:
+ raise TypeError('"step" must be a list or integer')
+ self.step = step
+ self.gamma = gamma
+ self.min_momentum = min_momentum
+ super(StepMomentumUpdaterHook, self).__init__(**kwargs)
+
+ def get_momentum(self, runner, base_momentum):
+ progress = runner.epoch if self.by_epoch else runner.iter
+
+ # calculate exponential term
+ if isinstance(self.step, int):
+ exp = progress // self.step
+ else:
+ exp = len(self.step)
+ for i, s in enumerate(self.step):
+ if progress < s:
+ exp = i
+ break
+
+ momentum = base_momentum * (self.gamma**exp)
+ if self.min_momentum is not None:
+ # clip to a minimum value
+ momentum = max(momentum, self.min_momentum)
+ return momentum
+
+
+@HOOKS.register_module()
+class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
+
+ def __init__(self, min_momentum=None, min_momentum_ratio=None, **kwargs):
+ assert (min_momentum is None) ^ (min_momentum_ratio is None)
+ self.min_momentum = min_momentum
+ self.min_momentum_ratio = min_momentum_ratio
+ super(CosineAnnealingMomentumUpdaterHook, self).__init__(**kwargs)
+
+ def get_momentum(self, runner, base_momentum):
+ if self.by_epoch:
+ progress = runner.epoch
+ max_progress = runner.max_epochs
+ else:
+ progress = runner.iter
+ max_progress = runner.max_iters
+ if self.min_momentum_ratio is not None:
+ target_momentum = base_momentum * self.min_momentum_ratio
+ else:
+ target_momentum = self.min_momentum
+ return annealing_cos(base_momentum, target_momentum,
+ progress / max_progress)
+
+
+@HOOKS.register_module()
+class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
+ """Cyclic momentum Scheduler.
+
+ Implement the cyclical momentum scheduler policy described in
+ https://arxiv.org/pdf/1708.07120.pdf
+
+ This momentum scheduler usually used together with the CyclicLRUpdater
+ to improve the performance in the 3D detection area.
+
+ Attributes:
+ target_ratio (tuple[float]): Relative ratio of the lowest momentum and
+ the highest momentum to the initial momentum.
+ cyclic_times (int): Number of cycles during training
+ step_ratio_up (float): The ratio of the increasing process of momentum
+ in the total cycle.
+ by_epoch (bool): Whether to update momentum by epoch.
+ """
+
+ def __init__(self,
+ by_epoch=False,
+ target_ratio=(0.85 / 0.95, 1),
+ cyclic_times=1,
+ step_ratio_up=0.4,
+ **kwargs):
+ if isinstance(target_ratio, float):
+ target_ratio = (target_ratio, target_ratio / 1e5)
+ elif isinstance(target_ratio, tuple):
+ target_ratio = (target_ratio[0], target_ratio[0] / 1e5) \
+ if len(target_ratio) == 1 else target_ratio
+ else:
+ raise ValueError('target_ratio should be either float '
+ f'or tuple, got {type(target_ratio)}')
+
+ assert len(target_ratio) == 2, \
+ '"target_ratio" must be list or tuple of two floats'
+ assert 0 <= step_ratio_up < 1.0, \
+ '"step_ratio_up" must be in range [0,1)'
+
+ self.target_ratio = target_ratio
+ self.cyclic_times = cyclic_times
+ self.step_ratio_up = step_ratio_up
+ self.momentum_phases = [] # init momentum_phases
+ # currently only support by_epoch=False
+ assert not by_epoch, \
+ 'currently only support "by_epoch" = False'
+ super(CyclicMomentumUpdaterHook, self).__init__(by_epoch, **kwargs)
+
+ def before_run(self, runner):
+ super(CyclicMomentumUpdaterHook, self).before_run(runner)
+ # initiate momentum_phases
+ # total momentum_phases are separated as up and down
+ max_iter_per_phase = runner.max_iters // self.cyclic_times
+ iter_up_phase = int(self.step_ratio_up * max_iter_per_phase)
+ self.momentum_phases.append(
+ [0, iter_up_phase, max_iter_per_phase, 1, self.target_ratio[0]])
+ self.momentum_phases.append([
+ iter_up_phase, max_iter_per_phase, max_iter_per_phase,
+ self.target_ratio[0], self.target_ratio[1]
+ ])
+
+ def get_momentum(self, runner, base_momentum):
+ curr_iter = runner.iter
+ for (start_iter, end_iter, max_iter_per_phase, start_ratio,
+ end_ratio) in self.momentum_phases:
+ curr_iter %= max_iter_per_phase
+ if start_iter <= curr_iter < end_iter:
+ progress = curr_iter - start_iter
+ return annealing_cos(base_momentum * start_ratio,
+ base_momentum * end_ratio,
+ progress / (end_iter - start_iter))
+
+
+@HOOKS.register_module()
+class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
+ """OneCycle momentum Scheduler.
+
+ This momentum scheduler usually used together with the OneCycleLrUpdater
+ to improve the performance.
+
+ Args:
+ base_momentum (float or list): Lower momentum boundaries in the cycle
+ for each parameter group. Note that momentum is cycled inversely
+ to learning rate; at the peak of a cycle, momentum is
+ 'base_momentum' and learning rate is 'max_lr'.
+ Default: 0.85
+ max_momentum (float or list): Upper momentum boundaries in the cycle
+ for each parameter group. Functionally,
+ it defines the cycle amplitude (max_momentum - base_momentum).
+ Note that momentum is cycled inversely
+ to learning rate; at the start of a cycle, momentum is
+ 'max_momentum' and learning rate is 'base_lr'
+ Default: 0.95
+ pct_start (float): The percentage of the cycle (in number of steps)
+ spent increasing the learning rate.
+ Default: 0.3
+ anneal_strategy (str): {'cos', 'linear'}
+ Specifies the annealing strategy: 'cos' for cosine annealing,
+ 'linear' for linear annealing.
+ Default: 'cos'
+ three_phase (bool): If three_phase is True, use a third phase of the
+ schedule to annihilate the learning rate according to
+ final_div_factor instead of modifying the second phase (the first
+ two phases will be symmetrical about the step indicated by
+ pct_start).
+ Default: False
+ """
+
+ def __init__(self,
+ base_momentum=0.85,
+ max_momentum=0.95,
+ pct_start=0.3,
+ anneal_strategy='cos',
+ three_phase=False,
+ **kwargs):
+ # validate by_epoch, currently only support by_epoch=False
+ if 'by_epoch' not in kwargs:
+ kwargs['by_epoch'] = False
+ else:
+ assert not kwargs['by_epoch'], \
+ 'currently only support "by_epoch" = False'
+ if not isinstance(base_momentum, (float, list, dict)):
+ raise ValueError('base_momentum must be the type among of float,'
+ 'list or dict.')
+ self._base_momentum = base_momentum
+ if not isinstance(max_momentum, (float, list, dict)):
+ raise ValueError('max_momentum must be the type among of float,'
+ 'list or dict.')
+ self._max_momentum = max_momentum
+ # validate pct_start
+ if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
+ raise ValueError('Expected float between 0 and 1 pct_start, but '
+ f'got {pct_start}')
+ self.pct_start = pct_start
+ # validate anneal_strategy
+ if anneal_strategy not in ['cos', 'linear']:
+ raise ValueError('anneal_strategy must by one of "cos" or '
+ f'"linear", instead got {anneal_strategy}')
+ elif anneal_strategy == 'cos':
+ self.anneal_func = annealing_cos
+ elif anneal_strategy == 'linear':
+ self.anneal_func = annealing_linear
+ self.three_phase = three_phase
+ self.momentum_phases = [] # init momentum_phases
+ super(OneCycleMomentumUpdaterHook, self).__init__(**kwargs)
+
+ def before_run(self, runner):
+ if isinstance(runner.optimizer, dict):
+ for k, optim in runner.optimizer.items():
+ if ('momentum' not in optim.defaults
+ and 'betas' not in optim.defaults):
+ raise ValueError('optimizer must support momentum with'
+ 'option enabled')
+ self.use_beta1 = 'betas' in optim.defaults
+ _base_momentum = format_param(k, optim, self._base_momentum)
+ _max_momentum = format_param(k, optim, self._max_momentum)
+ for group, b_momentum, m_momentum in zip(
+ optim.param_groups, _base_momentum, _max_momentum):
+ if self.use_beta1:
+ _, beta2 = group['betas']
+ group['betas'] = (m_momentum, beta2)
+ else:
+ group['momentum'] = m_momentum
+ group['base_momentum'] = b_momentum
+ group['max_momentum'] = m_momentum
+ else:
+ optim = runner.optimizer
+ if ('momentum' not in optim.defaults
+ and 'betas' not in optim.defaults):
+ raise ValueError('optimizer must support momentum with'
+ 'option enabled')
+ self.use_beta1 = 'betas' in optim.defaults
+ k = type(optim).__name__
+ _base_momentum = format_param(k, optim, self._base_momentum)
+ _max_momentum = format_param(k, optim, self._max_momentum)
+ for group, b_momentum, m_momentum in zip(optim.param_groups,
+ _base_momentum,
+ _max_momentum):
+ if self.use_beta1:
+ _, beta2 = group['betas']
+ group['betas'] = (m_momentum, beta2)
+ else:
+ group['momentum'] = m_momentum
+ group['base_momentum'] = b_momentum
+ group['max_momentum'] = m_momentum
+
+ if self.three_phase:
+ self.momentum_phases.append({
+ 'end_iter':
+ float(self.pct_start * runner.max_iters) - 1,
+ 'start_momentum':
+ 'max_momentum',
+ 'end_momentum':
+ 'base_momentum'
+ })
+ self.momentum_phases.append({
+ 'end_iter':
+ float(2 * self.pct_start * runner.max_iters) - 2,
+ 'start_momentum':
+ 'base_momentum',
+ 'end_momentum':
+ 'max_momentum'
+ })
+ self.momentum_phases.append({
+ 'end_iter': runner.max_iters - 1,
+ 'start_momentum': 'max_momentum',
+ 'end_momentum': 'max_momentum'
+ })
+ else:
+ self.momentum_phases.append({
+ 'end_iter':
+ float(self.pct_start * runner.max_iters) - 1,
+ 'start_momentum':
+ 'max_momentum',
+ 'end_momentum':
+ 'base_momentum'
+ })
+ self.momentum_phases.append({
+ 'end_iter': runner.max_iters - 1,
+ 'start_momentum': 'base_momentum',
+ 'end_momentum': 'max_momentum'
+ })
+
+ def _set_momentum(self, runner, momentum_groups):
+ if isinstance(runner.optimizer, dict):
+ for k, optim in runner.optimizer.items():
+ for param_group, mom in zip(optim.param_groups,
+ momentum_groups[k]):
+ if 'momentum' in param_group.keys():
+ param_group['momentum'] = mom
+ elif 'betas' in param_group.keys():
+ param_group['betas'] = (mom, param_group['betas'][1])
+ else:
+ for param_group, mom in zip(runner.optimizer.param_groups,
+ momentum_groups):
+ if 'momentum' in param_group.keys():
+ param_group['momentum'] = mom
+ elif 'betas' in param_group.keys():
+ param_group['betas'] = (mom, param_group['betas'][1])
+
+ def get_momentum(self, runner, param_group):
+ curr_iter = runner.iter
+ start_iter = 0
+ for i, phase in enumerate(self.momentum_phases):
+ end_iter = phase['end_iter']
+ if curr_iter <= end_iter or i == len(self.momentum_phases) - 1:
+ pct = (curr_iter - start_iter) / (end_iter - start_iter)
+ momentum = self.anneal_func(
+ param_group[phase['start_momentum']],
+ param_group[phase['end_momentum']], pct)
+ break
+ start_iter = end_iter
+ return momentum
+
+ def get_regular_momentum(self, runner):
+ if isinstance(runner.optimizer, dict):
+ momentum_groups = {}
+ for k, optim in runner.optimizer.items():
+ _momentum_group = [
+ self.get_momentum(runner, param_group)
+ for param_group in optim.param_groups
+ ]
+ momentum_groups.update({k: _momentum_group})
+ return momentum_groups
+ else:
+ momentum_groups = []
+ for param_group in runner.optimizer.param_groups:
+ momentum_groups.append(self.get_momentum(runner, param_group))
+ return momentum_groups
diff --git a/src/custom_mmpkg/custom_mmcv/runner/hooks/optimizer.py b/src/custom_mmpkg/custom_mmcv/runner/hooks/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7f111733b6d37a86dc396442e39b67a8880c99a
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/hooks/optimizer.py
@@ -0,0 +1,508 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from collections import defaultdict
+from itertools import chain
+
+from torch.nn.utils import clip_grad
+
+from custom_mmpkg.custom_mmcv.utils import TORCH_VERSION, _BatchNorm, digit_version
+from ..dist_utils import allreduce_grads
+from ..fp16_utils import LossScaler, wrap_fp16_model
+from .hook import HOOKS, Hook
+
+try:
+ # If PyTorch version >= 1.6.0, torch.cuda.amp.GradScaler would be imported
+ # and used; otherwise, auto fp16 will adopt mmcv's implementation.
+ from torch.cuda.amp import GradScaler
+except ImportError:
+ pass
+
+
+@HOOKS.register_module()
+class OptimizerHook(Hook):
+
+ def __init__(self, grad_clip=None):
+ self.grad_clip = grad_clip
+
+ def clip_grads(self, params):
+ params = list(
+ filter(lambda p: p.requires_grad and p.grad is not None, params))
+ if len(params) > 0:
+ return clip_grad.clip_grad_norm_(params, **self.grad_clip)
+
+ def after_train_iter(self, runner):
+ runner.optimizer.zero_grad()
+ runner.outputs['loss'].backward()
+ if self.grad_clip is not None:
+ grad_norm = self.clip_grads(runner.model.parameters())
+ if grad_norm is not None:
+ # Add grad norm to the logger
+ runner.log_buffer.update({'grad_norm': float(grad_norm)},
+ runner.outputs['num_samples'])
+ runner.optimizer.step()
+
+
+@HOOKS.register_module()
+class GradientCumulativeOptimizerHook(OptimizerHook):
+ """Optimizer Hook implements multi-iters gradient cumulating.
+
+ Args:
+ cumulative_iters (int, optional): Num of gradient cumulative iters.
+ The optimizer will step every `cumulative_iters` iters.
+ Defaults to 1.
+
+ Examples:
+ >>> # Use cumulative_iters to simulate a large batch size
+ >>> # It is helpful when the hardware cannot handle a large batch size.
+ >>> loader = DataLoader(data, batch_size=64)
+ >>> optim_hook = GradientCumulativeOptimizerHook(cumulative_iters=4)
+ >>> # almost equals to
+ >>> loader = DataLoader(data, batch_size=256)
+ >>> optim_hook = OptimizerHook()
+ """
+
+ def __init__(self, cumulative_iters=1, **kwargs):
+ super(GradientCumulativeOptimizerHook, self).__init__(**kwargs)
+
+ assert isinstance(cumulative_iters, int) and cumulative_iters > 0, \
+ f'cumulative_iters only accepts positive int, but got ' \
+ f'{type(cumulative_iters)} instead.'
+
+ self.cumulative_iters = cumulative_iters
+ self.divisible_iters = 0
+ self.remainder_iters = 0
+ self.initialized = False
+
+ def has_batch_norm(self, module):
+ if isinstance(module, _BatchNorm):
+ return True
+ for m in module.children():
+ if self.has_batch_norm(m):
+ return True
+ return False
+
+ def _init(self, runner):
+ if runner.iter % self.cumulative_iters != 0:
+ runner.logger.warning(
+ 'Resume iter number is not divisible by cumulative_iters in '
+ 'GradientCumulativeOptimizerHook, which means the gradient of '
+ 'some iters is lost and the result may be influenced slightly.'
+ )
+
+ if self.has_batch_norm(runner.model) and self.cumulative_iters > 1:
+ runner.logger.warning(
+ 'GradientCumulativeOptimizerHook may slightly decrease '
+ 'performance if the model has BatchNorm layers.')
+
+ residual_iters = runner.max_iters - runner.iter
+
+ self.divisible_iters = (
+ residual_iters // self.cumulative_iters * self.cumulative_iters)
+ self.remainder_iters = residual_iters - self.divisible_iters
+
+ self.initialized = True
+
+ def after_train_iter(self, runner):
+ if not self.initialized:
+ self._init(runner)
+
+ if runner.iter < self.divisible_iters:
+ loss_factor = self.cumulative_iters
+ else:
+ loss_factor = self.remainder_iters
+ loss = runner.outputs['loss']
+ loss = loss / loss_factor
+ loss.backward()
+
+ if (self.every_n_iters(runner, self.cumulative_iters)
+ or self.is_last_iter(runner)):
+
+ if self.grad_clip is not None:
+ grad_norm = self.clip_grads(runner.model.parameters())
+ if grad_norm is not None:
+ # Add grad norm to the logger
+ runner.log_buffer.update({'grad_norm': float(grad_norm)},
+ runner.outputs['num_samples'])
+ runner.optimizer.step()
+ runner.optimizer.zero_grad()
+
+
+if (TORCH_VERSION != 'parrots'
+ and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
+
+ @HOOKS.register_module()
+ class Fp16OptimizerHook(OptimizerHook):
+ """FP16 optimizer hook (using PyTorch's implementation).
+
+ If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend,
+ to take care of the optimization procedure.
+
+ Args:
+ loss_scale (float | str | dict): Scale factor configuration.
+ If loss_scale is a float, static loss scaling will be used with
+ the specified scale. If loss_scale is a string, it must be
+ 'dynamic', then dynamic loss scaling will be used.
+ It can also be a dict containing arguments of GradScalar.
+ Defaults to 512. For Pytorch >= 1.6, mmcv uses official
+ implementation of GradScaler. If you use a dict version of
+ loss_scale to create GradScaler, please refer to:
+ https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
+ for the parameters.
+
+ Examples:
+ >>> loss_scale = dict(
+ ... init_scale=65536.0,
+ ... growth_factor=2.0,
+ ... backoff_factor=0.5,
+ ... growth_interval=2000
+ ... )
+ >>> optimizer_hook = Fp16OptimizerHook(loss_scale=loss_scale)
+ """
+
+ def __init__(self,
+ grad_clip=None,
+ coalesce=True,
+ bucket_size_mb=-1,
+ loss_scale=512.,
+ distributed=True):
+ self.grad_clip = grad_clip
+ self.coalesce = coalesce
+ self.bucket_size_mb = bucket_size_mb
+ self.distributed = distributed
+ self._scale_update_param = None
+ if loss_scale == 'dynamic':
+ self.loss_scaler = GradScaler()
+ elif isinstance(loss_scale, float):
+ self._scale_update_param = loss_scale
+ self.loss_scaler = GradScaler(init_scale=loss_scale)
+ elif isinstance(loss_scale, dict):
+ self.loss_scaler = GradScaler(**loss_scale)
+ else:
+ raise ValueError('loss_scale must be of type float, dict, or '
+ f'"dynamic", got {loss_scale}')
+
+ def before_run(self, runner):
+ """Preparing steps before Mixed Precision Training."""
+ # wrap model mode to fp16
+ wrap_fp16_model(runner.model)
+ # resume from state dict
+ if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']:
+ scaler_state_dict = runner.meta['fp16']['loss_scaler']
+ self.loss_scaler.load_state_dict(scaler_state_dict)
+
+ def copy_grads_to_fp32(self, fp16_net, fp32_weights):
+ """Copy gradients from fp16 model to fp32 weight copy."""
+ for fp32_param, fp16_param in zip(fp32_weights,
+ fp16_net.parameters()):
+ if fp16_param.grad is not None:
+ if fp32_param.grad is None:
+ fp32_param.grad = fp32_param.data.new(
+ fp32_param.size())
+ fp32_param.grad.copy_(fp16_param.grad)
+
+ def copy_params_to_fp16(self, fp16_net, fp32_weights):
+ """Copy updated params from fp32 weight copy to fp16 model."""
+ for fp16_param, fp32_param in zip(fp16_net.parameters(),
+ fp32_weights):
+ fp16_param.data.copy_(fp32_param.data)
+
+ def after_train_iter(self, runner):
+ """Backward optimization steps for Mixed Precision Training. For
+ dynamic loss scaling, please refer to
+ https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.
+
+ 1. Scale the loss by a scale factor.
+ 2. Backward the loss to obtain the gradients.
+ 3. Unscale the optimizer’s gradient tensors.
+ 4. Call optimizer.step() and update scale factor.
+ 5. Save loss_scaler state_dict for resume purpose.
+ """
+ # clear grads of last iteration
+ runner.model.zero_grad()
+ runner.optimizer.zero_grad()
+
+ self.loss_scaler.scale(runner.outputs['loss']).backward()
+ self.loss_scaler.unscale_(runner.optimizer)
+ # grad clip
+ if self.grad_clip is not None:
+ grad_norm = self.clip_grads(runner.model.parameters())
+ if grad_norm is not None:
+ # Add grad norm to the logger
+ runner.log_buffer.update({'grad_norm': float(grad_norm)},
+ runner.outputs['num_samples'])
+ # backward and update scaler
+ self.loss_scaler.step(runner.optimizer)
+ self.loss_scaler.update(self._scale_update_param)
+
+ # save state_dict of loss_scaler
+ runner.meta.setdefault(
+ 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
+
+ @HOOKS.register_module()
+ class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook,
+ Fp16OptimizerHook):
+ """Fp16 optimizer Hook (using PyTorch's implementation) implements
+ multi-iters gradient cumulating.
+
+ If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend,
+ to take care of the optimization procedure.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(GradientCumulativeFp16OptimizerHook,
+ self).__init__(*args, **kwargs)
+
+ def after_train_iter(self, runner):
+ if not self.initialized:
+ self._init(runner)
+
+ if runner.iter < self.divisible_iters:
+ loss_factor = self.cumulative_iters
+ else:
+ loss_factor = self.remainder_iters
+ loss = runner.outputs['loss']
+ loss = loss / loss_factor
+
+ self.loss_scaler.scale(loss).backward()
+
+ if (self.every_n_iters(runner, self.cumulative_iters)
+ or self.is_last_iter(runner)):
+
+ # copy fp16 grads in the model to fp32 params in the optimizer
+ self.loss_scaler.unscale_(runner.optimizer)
+
+ if self.grad_clip is not None:
+ grad_norm = self.clip_grads(runner.model.parameters())
+ if grad_norm is not None:
+ # Add grad norm to the logger
+ runner.log_buffer.update(
+ {'grad_norm': float(grad_norm)},
+ runner.outputs['num_samples'])
+
+ # backward and update scaler
+ self.loss_scaler.step(runner.optimizer)
+ self.loss_scaler.update(self._scale_update_param)
+
+ # save state_dict of loss_scaler
+ runner.meta.setdefault(
+ 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
+
+ # clear grads
+ runner.model.zero_grad()
+ runner.optimizer.zero_grad()
+
+else:
+
+ @HOOKS.register_module()
+ class Fp16OptimizerHook(OptimizerHook):
+ """FP16 optimizer hook (mmcv's implementation).
+
+ The steps of fp16 optimizer is as follows.
+ 1. Scale the loss value.
+ 2. BP in the fp16 model.
+ 2. Copy gradients from fp16 model to fp32 weights.
+ 3. Update fp32 weights.
+ 4. Copy updated parameters from fp32 weights to fp16 model.
+
+ Refer to https://arxiv.org/abs/1710.03740 for more details.
+
+ Args:
+ loss_scale (float | str | dict): Scale factor configuration.
+ If loss_scale is a float, static loss scaling will be used with
+ the specified scale. If loss_scale is a string, it must be
+ 'dynamic', then dynamic loss scaling will be used.
+ It can also be a dict containing arguments of LossScaler.
+ Defaults to 512.
+ """
+
+ def __init__(self,
+ grad_clip=None,
+ coalesce=True,
+ bucket_size_mb=-1,
+ loss_scale=512.,
+ distributed=True):
+ self.grad_clip = grad_clip
+ self.coalesce = coalesce
+ self.bucket_size_mb = bucket_size_mb
+ self.distributed = distributed
+ if loss_scale == 'dynamic':
+ self.loss_scaler = LossScaler(mode='dynamic')
+ elif isinstance(loss_scale, float):
+ self.loss_scaler = LossScaler(
+ init_scale=loss_scale, mode='static')
+ elif isinstance(loss_scale, dict):
+ self.loss_scaler = LossScaler(**loss_scale)
+ else:
+ raise ValueError('loss_scale must be of type float, dict, or '
+ f'"dynamic", got {loss_scale}')
+
+ def before_run(self, runner):
+ """Preparing steps before Mixed Precision Training.
+
+ 1. Make a master copy of fp32 weights for optimization.
+ 2. Convert the main model from fp32 to fp16.
+ """
+ # keep a copy of fp32 weights
+ old_groups = runner.optimizer.param_groups
+ runner.optimizer.param_groups = copy.deepcopy(
+ runner.optimizer.param_groups)
+ state = defaultdict(dict)
+ p_map = {
+ old_p: p
+ for old_p, p in zip(
+ chain(*(g['params'] for g in old_groups)),
+ chain(*(g['params']
+ for g in runner.optimizer.param_groups)))
+ }
+ for k, v in runner.optimizer.state.items():
+ state[p_map[k]] = v
+ runner.optimizer.state = state
+ # convert model to fp16
+ wrap_fp16_model(runner.model)
+ # resume from state dict
+ if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']:
+ scaler_state_dict = runner.meta['fp16']['loss_scaler']
+ self.loss_scaler.load_state_dict(scaler_state_dict)
+
+ def copy_grads_to_fp32(self, fp16_net, fp32_weights):
+ """Copy gradients from fp16 model to fp32 weight copy."""
+ for fp32_param, fp16_param in zip(fp32_weights,
+ fp16_net.parameters()):
+ if fp16_param.grad is not None:
+ if fp32_param.grad is None:
+ fp32_param.grad = fp32_param.data.new(
+ fp32_param.size())
+ fp32_param.grad.copy_(fp16_param.grad)
+
+ def copy_params_to_fp16(self, fp16_net, fp32_weights):
+ """Copy updated params from fp32 weight copy to fp16 model."""
+ for fp16_param, fp32_param in zip(fp16_net.parameters(),
+ fp32_weights):
+ fp16_param.data.copy_(fp32_param.data)
+
+ def after_train_iter(self, runner):
+ """Backward optimization steps for Mixed Precision Training. For
+ dynamic loss scaling, please refer `loss_scalar.py`
+
+ 1. Scale the loss by a scale factor.
+ 2. Backward the loss to obtain the gradients (fp16).
+ 3. Copy gradients from the model to the fp32 weight copy.
+ 4. Scale the gradients back and update the fp32 weight copy.
+ 5. Copy back the params from fp32 weight copy to the fp16 model.
+ 6. Save loss_scaler state_dict for resume purpose.
+ """
+ # clear grads of last iteration
+ runner.model.zero_grad()
+ runner.optimizer.zero_grad()
+ # scale the loss value
+ scaled_loss = runner.outputs['loss'] * self.loss_scaler.loss_scale
+ scaled_loss.backward()
+ # copy fp16 grads in the model to fp32 params in the optimizer
+
+ fp32_weights = []
+ for param_group in runner.optimizer.param_groups:
+ fp32_weights += param_group['params']
+ self.copy_grads_to_fp32(runner.model, fp32_weights)
+ # allreduce grads
+ if self.distributed:
+ allreduce_grads(fp32_weights, self.coalesce,
+ self.bucket_size_mb)
+
+ has_overflow = self.loss_scaler.has_overflow(fp32_weights)
+ # if has overflow, skip this iteration
+ if not has_overflow:
+ # scale the gradients back
+ for param in fp32_weights:
+ if param.grad is not None:
+ param.grad.div_(self.loss_scaler.loss_scale)
+ if self.grad_clip is not None:
+ grad_norm = self.clip_grads(fp32_weights)
+ if grad_norm is not None:
+ # Add grad norm to the logger
+ runner.log_buffer.update(
+ {'grad_norm': float(grad_norm)},
+ runner.outputs['num_samples'])
+ # update fp32 params
+ runner.optimizer.step()
+ # copy fp32 params to the fp16 model
+ self.copy_params_to_fp16(runner.model, fp32_weights)
+ self.loss_scaler.update_scale(has_overflow)
+ if has_overflow:
+ runner.logger.warning('Check overflow, downscale loss scale '
+ f'to {self.loss_scaler.cur_scale}')
+
+ # save state_dict of loss_scaler
+ runner.meta.setdefault(
+ 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
+
+ @HOOKS.register_module()
+ class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook,
+ Fp16OptimizerHook):
+ """Fp16 optimizer Hook (using mmcv implementation) implements multi-
+ iters gradient cumulating."""
+
+ def __init__(self, *args, **kwargs):
+ super(GradientCumulativeFp16OptimizerHook,
+ self).__init__(*args, **kwargs)
+
+ def after_train_iter(self, runner):
+ if not self.initialized:
+ self._init(runner)
+
+ if runner.iter < self.divisible_iters:
+ loss_factor = self.cumulative_iters
+ else:
+ loss_factor = self.remainder_iters
+
+ loss = runner.outputs['loss']
+ loss = loss / loss_factor
+
+ # scale the loss value
+ scaled_loss = loss * self.loss_scaler.loss_scale
+ scaled_loss.backward()
+
+ if (self.every_n_iters(runner, self.cumulative_iters)
+ or self.is_last_iter(runner)):
+
+ # copy fp16 grads in the model to fp32 params in the optimizer
+ fp32_weights = []
+ for param_group in runner.optimizer.param_groups:
+ fp32_weights += param_group['params']
+ self.copy_grads_to_fp32(runner.model, fp32_weights)
+ # allreduce grads
+ if self.distributed:
+ allreduce_grads(fp32_weights, self.coalesce,
+ self.bucket_size_mb)
+
+ has_overflow = self.loss_scaler.has_overflow(fp32_weights)
+ # if has overflow, skip this iteration
+ if not has_overflow:
+ # scale the gradients back
+ for param in fp32_weights:
+ if param.grad is not None:
+ param.grad.div_(self.loss_scaler.loss_scale)
+ if self.grad_clip is not None:
+ grad_norm = self.clip_grads(fp32_weights)
+ if grad_norm is not None:
+ # Add grad norm to the logger
+ runner.log_buffer.update(
+ {'grad_norm': float(grad_norm)},
+ runner.outputs['num_samples'])
+ # update fp32 params
+ runner.optimizer.step()
+ # copy fp32 params to the fp16 model
+ self.copy_params_to_fp16(runner.model, fp32_weights)
+ else:
+ runner.logger.warning(
+ 'Check overflow, downscale loss scale '
+ f'to {self.loss_scaler.cur_scale}')
+
+ self.loss_scaler.update_scale(has_overflow)
+
+ # save state_dict of loss_scaler
+ runner.meta.setdefault(
+ 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
+
+ # clear grads
+ runner.model.zero_grad()
+ runner.optimizer.zero_grad()
diff --git a/src/custom_mmpkg/custom_mmcv/runner/hooks/profiler.py b/src/custom_mmpkg/custom_mmcv/runner/hooks/profiler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b70236997eec59c2209ef351ae38863b4112d0ec
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/hooks/profiler.py
@@ -0,0 +1,180 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from typing import Callable, List, Optional, Union
+
+import torch
+
+from ..dist_utils import master_only
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class ProfilerHook(Hook):
+ """Profiler to analyze performance during training.
+
+ PyTorch Profiler is a tool that allows the collection of the performance
+ metrics during the training. More details on Profiler can be found at
+ https://pytorch.org/docs/1.8.1/profiler.html#torch.profiler.profile
+
+ Args:
+ by_epoch (bool): Profile performance by epoch or by iteration.
+ Default: True.
+ profile_iters (int): Number of iterations for profiling.
+ If ``by_epoch=True``, profile_iters indicates that they are the
+ first profile_iters epochs at the beginning of the
+ training, otherwise it indicates the first profile_iters
+ iterations. Default: 1.
+ activities (list[str]): List of activity groups (CPU, CUDA) to use in
+ profiling. Default: ['cpu', 'cuda'].
+ schedule (dict, optional): Config of generating the callable schedule.
+ if schedule is None, profiler will not add step markers into the
+ trace and table view. Default: None.
+ on_trace_ready (callable, dict): Either a handler or a dict of generate
+ handler. Default: None.
+ record_shapes (bool): Save information about operator's input shapes.
+ Default: False.
+ profile_memory (bool): Track tensor memory allocation/deallocation.
+ Default: False.
+ with_stack (bool): Record source information (file and line number)
+ for the ops. Default: False.
+ with_flops (bool): Use formula to estimate the FLOPS of specific
+ operators (matrix multiplication and 2D convolution).
+ Default: False.
+ json_trace_path (str, optional): Exports the collected trace in Chrome
+ JSON format. Default: None.
+
+ Example:
+ >>> runner = ... # instantiate a Runner
+ >>> # tensorboard trace
+ >>> trace_config = dict(type='tb_trace', dir_name='work_dir')
+ >>> profiler_config = dict(on_trace_ready=trace_config)
+ >>> runner.register_profiler_hook(profiler_config)
+ >>> runner.run(data_loaders=[trainloader], workflow=[('train', 1)])
+ """
+
+ def __init__(self,
+ by_epoch: bool = True,
+ profile_iters: int = 1,
+ activities: List[str] = ['cpu', 'cuda'],
+ schedule: Optional[dict] = None,
+ on_trace_ready: Optional[Union[Callable, dict]] = None,
+ record_shapes: bool = False,
+ profile_memory: bool = False,
+ with_stack: bool = False,
+ with_flops: bool = False,
+ json_trace_path: Optional[str] = None) -> None:
+ try:
+ from torch import profiler # torch version >= 1.8.1
+ except ImportError:
+ raise ImportError('profiler is the new feature of torch1.8.1, '
+ f'but your version is {torch.__version__}')
+
+ assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean.'
+ self.by_epoch = by_epoch
+
+ if profile_iters < 1:
+ raise ValueError('profile_iters should be greater than 0, but got '
+ f'{profile_iters}')
+ self.profile_iters = profile_iters
+
+ if not isinstance(activities, list):
+ raise ValueError(
+ f'activities should be list, but got {type(activities)}')
+ self.activities = []
+ for activity in activities:
+ activity = activity.lower()
+ if activity == 'cpu':
+ self.activities.append(profiler.ProfilerActivity.CPU)
+ elif activity == 'cuda':
+ self.activities.append(profiler.ProfilerActivity.CUDA)
+ else:
+ raise ValueError(
+ f'activity should be "cpu" or "cuda", but got {activity}')
+
+ if schedule is not None:
+ self.schedule = profiler.schedule(**schedule)
+ else:
+ self.schedule = None
+
+ self.on_trace_ready = on_trace_ready
+ self.record_shapes = record_shapes
+ self.profile_memory = profile_memory
+ self.with_stack = with_stack
+ self.with_flops = with_flops
+ self.json_trace_path = json_trace_path
+
+ @master_only
+ def before_run(self, runner):
+ if self.by_epoch and runner.max_epochs < self.profile_iters:
+ raise ValueError('self.profile_iters should not be greater than '
+ f'{runner.max_epochs}')
+
+ if not self.by_epoch and runner.max_iters < self.profile_iters:
+ raise ValueError('self.profile_iters should not be greater than '
+ f'{runner.max_iters}')
+
+ if callable(self.on_trace_ready): # handler
+ _on_trace_ready = self.on_trace_ready
+ elif isinstance(self.on_trace_ready, dict): # config of handler
+ trace_cfg = self.on_trace_ready.copy()
+ trace_type = trace_cfg.pop('type') # log_trace handler
+ if trace_type == 'log_trace':
+
+ def _log_handler(prof):
+ print(prof.key_averages().table(**trace_cfg))
+
+ _on_trace_ready = _log_handler
+ elif trace_type == 'tb_trace': # tensorboard_trace handler
+ try:
+ import torch_tb_profiler # noqa: F401
+ except ImportError:
+ raise ImportError('please run "pip install '
+ 'torch-tb-profiler" to install '
+ 'torch_tb_profiler')
+ _on_trace_ready = torch.profiler.tensorboard_trace_handler(
+ **trace_cfg)
+ else:
+ raise ValueError('trace_type should be "log_trace" or '
+ f'"tb_trace", but got {trace_type}')
+ elif self.on_trace_ready is None:
+ _on_trace_ready = None # type: ignore
+ else:
+ raise ValueError('on_trace_ready should be handler, dict or None, '
+ f'but got {type(self.on_trace_ready)}')
+
+ if runner.max_epochs > 1:
+ warnings.warn(f'profiler will profile {runner.max_epochs} epochs '
+ 'instead of 1 epoch. Since profiler will slow down '
+ 'the training, it is recommended to train 1 epoch '
+ 'with ProfilerHook and adjust your setting according'
+ ' to the profiler summary. During normal training '
+ '(epoch > 1), you may disable the ProfilerHook.')
+
+ self.profiler = torch.profiler.profile(
+ activities=self.activities,
+ schedule=self.schedule,
+ on_trace_ready=_on_trace_ready,
+ record_shapes=self.record_shapes,
+ profile_memory=self.profile_memory,
+ with_stack=self.with_stack,
+ with_flops=self.with_flops)
+
+ self.profiler.__enter__()
+ runner.logger.info('profiler is profiling...')
+
+ @master_only
+ def after_train_epoch(self, runner):
+ if self.by_epoch and runner.epoch == self.profile_iters - 1:
+ runner.logger.info('profiler may take a few minutes...')
+ self.profiler.__exit__(None, None, None)
+ if self.json_trace_path is not None:
+ self.profiler.export_chrome_trace(self.json_trace_path)
+
+ @master_only
+ def after_train_iter(self, runner):
+ self.profiler.step()
+ if not self.by_epoch and runner.iter == self.profile_iters - 1:
+ runner.logger.info('profiler may take a few minutes...')
+ self.profiler.__exit__(None, None, None)
+ if self.json_trace_path is not None:
+ self.profiler.export_chrome_trace(self.json_trace_path)
diff --git a/src/custom_mmpkg/custom_mmcv/runner/hooks/sampler_seed.py b/src/custom_mmpkg/custom_mmcv/runner/hooks/sampler_seed.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee0dc6bdd8df5775857028aaed5444c0f59caf80
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/hooks/sampler_seed.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class DistSamplerSeedHook(Hook):
+ """Data-loading sampler for distributed training.
+
+ When distributed training, it is only useful in conjunction with
+ :obj:`EpochBasedRunner`, while :obj:`IterBasedRunner` achieves the same
+ purpose with :obj:`IterLoader`.
+ """
+
+ def before_epoch(self, runner):
+ if hasattr(runner.data_loader.sampler, 'set_epoch'):
+ # in case the data loader uses `SequentialSampler` in Pytorch
+ runner.data_loader.sampler.set_epoch(runner.epoch)
+ elif hasattr(runner.data_loader.batch_sampler.sampler, 'set_epoch'):
+ # batch sampler in pytorch warps the sampler as its attributes.
+ runner.data_loader.batch_sampler.sampler.set_epoch(runner.epoch)
diff --git a/src/custom_mmpkg/custom_mmcv/runner/hooks/sync_buffer.py b/src/custom_mmpkg/custom_mmcv/runner/hooks/sync_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6376b7ff894280cb2782243b25e8973650591577
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/hooks/sync_buffer.py
@@ -0,0 +1,22 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..dist_utils import allreduce_params
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class SyncBuffersHook(Hook):
+ """Synchronize model buffers such as running_mean and running_var in BN at
+ the end of each epoch.
+
+ Args:
+ distributed (bool): Whether distributed training is used. It is
+ effective only for distributed training. Defaults to True.
+ """
+
+ def __init__(self, distributed=True):
+ self.distributed = distributed
+
+ def after_epoch(self, runner):
+ """All-reduce model buffers at the end of each epoch."""
+ if self.distributed:
+ allreduce_params(runner.model.buffers())
diff --git a/src/custom_mmpkg/custom_mmcv/runner/iter_based_runner.py b/src/custom_mmpkg/custom_mmcv/runner/iter_based_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..075e4b93996c7e5c267a1cd01afd439a5ac06e53
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/iter_based_runner.py
@@ -0,0 +1,273 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import platform
+import shutil
+import time
+import warnings
+
+import torch
+from torch.optim import Optimizer
+
+import custom_mmpkg.custom_mmcv as mmcv
+from .base_runner import BaseRunner
+from .builder import RUNNERS
+from .checkpoint import save_checkpoint
+from .hooks import IterTimerHook
+from .utils import get_host_info
+
+
+class IterLoader:
+
+ def __init__(self, dataloader):
+ self._dataloader = dataloader
+ self.iter_loader = iter(self._dataloader)
+ self._epoch = 0
+
+ @property
+ def epoch(self):
+ return self._epoch
+
+ def __next__(self):
+ try:
+ data = next(self.iter_loader)
+ except StopIteration:
+ self._epoch += 1
+ if hasattr(self._dataloader.sampler, 'set_epoch'):
+ self._dataloader.sampler.set_epoch(self._epoch)
+ time.sleep(2) # Prevent possible deadlock during epoch transition
+ self.iter_loader = iter(self._dataloader)
+ data = next(self.iter_loader)
+
+ return data
+
+ def __len__(self):
+ return len(self._dataloader)
+
+
+@RUNNERS.register_module()
+class IterBasedRunner(BaseRunner):
+ """Iteration-based Runner.
+
+ This runner train models iteration by iteration.
+ """
+
+ def train(self, data_loader, **kwargs):
+ self.model.train()
+ self.mode = 'train'
+ self.data_loader = data_loader
+ self._epoch = data_loader.epoch
+ data_batch = next(data_loader)
+ self.call_hook('before_train_iter')
+ outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
+ if not isinstance(outputs, dict):
+ raise TypeError('model.train_step() must return a dict')
+ if 'log_vars' in outputs:
+ self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
+ self.outputs = outputs
+ self.call_hook('after_train_iter')
+ self._inner_iter += 1
+ self._iter += 1
+
+ @torch.no_grad()
+ def val(self, data_loader, **kwargs):
+ self.model.eval()
+ self.mode = 'val'
+ self.data_loader = data_loader
+ data_batch = next(data_loader)
+ self.call_hook('before_val_iter')
+ outputs = self.model.val_step(data_batch, **kwargs)
+ if not isinstance(outputs, dict):
+ raise TypeError('model.val_step() must return a dict')
+ if 'log_vars' in outputs:
+ self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
+ self.outputs = outputs
+ self.call_hook('after_val_iter')
+ self._inner_iter += 1
+
+ def run(self, data_loaders, workflow, max_iters=None, **kwargs):
+ """Start running.
+
+ Args:
+ data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
+ and validation.
+ workflow (list[tuple]): A list of (phase, iters) to specify the
+ running order and iterations. E.g, [('train', 10000),
+ ('val', 1000)] means running 10000 iterations for training and
+ 1000 iterations for validation, iteratively.
+ """
+ assert isinstance(data_loaders, list)
+ assert mmcv.is_list_of(workflow, tuple)
+ assert len(data_loaders) == len(workflow)
+ if max_iters is not None:
+ warnings.warn(
+ 'setting max_iters in run is deprecated, '
+ 'please set max_iters in runner_config', DeprecationWarning)
+ self._max_iters = max_iters
+ assert self._max_iters is not None, (
+ 'max_iters must be specified during instantiation')
+
+ work_dir = self.work_dir if self.work_dir is not None else 'NONE'
+ self.logger.info('Start running, host: %s, work_dir: %s',
+ get_host_info(), work_dir)
+ self.logger.info('Hooks will be executed in the following order:\n%s',
+ self.get_hook_info())
+ self.logger.info('workflow: %s, max: %d iters', workflow,
+ self._max_iters)
+ self.call_hook('before_run')
+
+ iter_loaders = [IterLoader(x) for x in data_loaders]
+
+ self.call_hook('before_epoch')
+
+ while self.iter < self._max_iters:
+ for i, flow in enumerate(workflow):
+ self._inner_iter = 0
+ mode, iters = flow
+ if not isinstance(mode, str) or not hasattr(self, mode):
+ raise ValueError(
+ 'runner has no method named "{}" to run a workflow'.
+ format(mode))
+ iter_runner = getattr(self, mode)
+ for _ in range(iters):
+ if mode == 'train' and self.iter >= self._max_iters:
+ break
+ iter_runner(iter_loaders[i], **kwargs)
+
+ time.sleep(1) # wait for some hooks like loggers to finish
+ self.call_hook('after_epoch')
+ self.call_hook('after_run')
+
+ def resume(self,
+ checkpoint,
+ resume_optimizer=True,
+ map_location='default'):
+ """Resume model from checkpoint.
+
+ Args:
+ checkpoint (str): Checkpoint to resume from.
+ resume_optimizer (bool, optional): Whether resume the optimizer(s)
+ if the checkpoint file includes optimizer(s). Default to True.
+ map_location (str, optional): Same as :func:`torch.load`.
+ Default to 'default'.
+ """
+ if map_location == 'default':
+ device_id = torch.cuda.current_device()
+ checkpoint = self.load_checkpoint(
+ checkpoint,
+ map_location=lambda storage, loc: storage.cuda(device_id))
+ else:
+ checkpoint = self.load_checkpoint(
+ checkpoint, map_location=map_location)
+
+ self._epoch = checkpoint['meta']['epoch']
+ self._iter = checkpoint['meta']['iter']
+ self._inner_iter = checkpoint['meta']['iter']
+ if 'optimizer' in checkpoint and resume_optimizer:
+ if isinstance(self.optimizer, Optimizer):
+ self.optimizer.load_state_dict(checkpoint['optimizer'])
+ elif isinstance(self.optimizer, dict):
+ for k in self.optimizer.keys():
+ self.optimizer[k].load_state_dict(
+ checkpoint['optimizer'][k])
+ else:
+ raise TypeError(
+ 'Optimizer should be dict or torch.optim.Optimizer '
+ f'but got {type(self.optimizer)}')
+
+ self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')
+
+ def save_checkpoint(self,
+ out_dir,
+ filename_tmpl='iter_{}.pth',
+ meta=None,
+ save_optimizer=True,
+ create_symlink=True):
+ """Save checkpoint to file.
+
+ Args:
+ out_dir (str): Directory to save checkpoint files.
+ filename_tmpl (str, optional): Checkpoint file template.
+ Defaults to 'iter_{}.pth'.
+ meta (dict, optional): Metadata to be saved in checkpoint.
+ Defaults to None.
+ save_optimizer (bool, optional): Whether save optimizer.
+ Defaults to True.
+ create_symlink (bool, optional): Whether create symlink to the
+ latest checkpoint file. Defaults to True.
+ """
+ if meta is None:
+ meta = {}
+ elif not isinstance(meta, dict):
+ raise TypeError(
+ f'meta should be a dict or None, but got {type(meta)}')
+ if self.meta is not None:
+ meta.update(self.meta)
+ # Note: meta.update(self.meta) should be done before
+ # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise
+ # there will be problems with resumed checkpoints.
+ # More details in https://github.com/open-mmlab/mmcv/pull/1108
+ meta.update(epoch=self.epoch + 1, iter=self.iter)
+
+ filename = filename_tmpl.format(self.iter + 1)
+ filepath = osp.join(out_dir, filename)
+ optimizer = self.optimizer if save_optimizer else None
+ save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
+ # in some environments, `os.symlink` is not supported, you may need to
+ # set `create_symlink` to False
+ if create_symlink:
+ dst_file = osp.join(out_dir, 'latest.pth')
+ if platform.system() != 'Windows':
+ mmcv.symlink(filename, dst_file)
+ else:
+ shutil.copy(filepath, dst_file)
+
+ def register_training_hooks(self,
+ lr_config,
+ optimizer_config=None,
+ checkpoint_config=None,
+ log_config=None,
+ momentum_config=None,
+ custom_hooks_config=None):
+ """Register default hooks for iter-based training.
+
+ Checkpoint hook, optimizer stepper hook and logger hooks will be set to
+ `by_epoch=False` by default.
+
+ Default hooks include:
+
+ +----------------------+-------------------------+
+ | Hooks | Priority |
+ +======================+=========================+
+ | LrUpdaterHook | VERY_HIGH (10) |
+ +----------------------+-------------------------+
+ | MomentumUpdaterHook | HIGH (30) |
+ +----------------------+-------------------------+
+ | OptimizerStepperHook | ABOVE_NORMAL (40) |
+ +----------------------+-------------------------+
+ | CheckpointSaverHook | NORMAL (50) |
+ +----------------------+-------------------------+
+ | IterTimerHook | LOW (70) |
+ +----------------------+-------------------------+
+ | LoggerHook(s) | VERY_LOW (90) |
+ +----------------------+-------------------------+
+ | CustomHook(s) | defaults to NORMAL (50) |
+ +----------------------+-------------------------+
+
+ If custom hooks have same priority with default hooks, custom hooks
+ will be triggered after default hooks.
+ """
+ if checkpoint_config is not None:
+ checkpoint_config.setdefault('by_epoch', False)
+ if lr_config is not None:
+ lr_config.setdefault('by_epoch', False)
+ if log_config is not None:
+ for info in log_config['hooks']:
+ info.setdefault('by_epoch', False)
+ super(IterBasedRunner, self).register_training_hooks(
+ lr_config=lr_config,
+ momentum_config=momentum_config,
+ optimizer_config=optimizer_config,
+ checkpoint_config=checkpoint_config,
+ log_config=log_config,
+ timer_config=IterTimerHook(),
+ custom_hooks_config=custom_hooks_config)
diff --git a/src/custom_mmpkg/custom_mmcv/runner/log_buffer.py b/src/custom_mmpkg/custom_mmcv/runner/log_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d949e2941c5400088c7cd8a1dc893d8b233ae785
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/log_buffer.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import OrderedDict
+
+import numpy as np
+
+
+class LogBuffer:
+
+ def __init__(self):
+ self.val_history = OrderedDict()
+ self.n_history = OrderedDict()
+ self.output = OrderedDict()
+ self.ready = False
+
+ def clear(self):
+ self.val_history.clear()
+ self.n_history.clear()
+ self.clear_output()
+
+ def clear_output(self):
+ self.output.clear()
+ self.ready = False
+
+ def update(self, vars, count=1):
+ assert isinstance(vars, dict)
+ for key, var in vars.items():
+ if key not in self.val_history:
+ self.val_history[key] = []
+ self.n_history[key] = []
+ self.val_history[key].append(var)
+ self.n_history[key].append(count)
+
+ def average(self, n=0):
+ """Average latest n values or all values."""
+ assert n >= 0
+ for key in self.val_history:
+ values = np.array(self.val_history[key][-n:])
+ nums = np.array(self.n_history[key][-n:])
+ avg = np.sum(values * nums) / np.sum(nums)
+ self.output[key] = avg
+ self.ready = True
diff --git a/src/custom_mmpkg/custom_mmcv/runner/optimizer/__init__.py b/src/custom_mmpkg/custom_mmcv/runner/optimizer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..53c34d0470992cbc374f29681fdd00dc0e57968d
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/optimizer/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .builder import (OPTIMIZER_BUILDERS, OPTIMIZERS, build_optimizer,
+ build_optimizer_constructor)
+from .default_constructor import DefaultOptimizerConstructor
+
+__all__ = [
+ 'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
+ 'build_optimizer', 'build_optimizer_constructor'
+]
diff --git a/src/custom_mmpkg/custom_mmcv/runner/optimizer/builder.py b/src/custom_mmpkg/custom_mmcv/runner/optimizer/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9234eed8f1f186d9d8dfda34562157ee39bdb3a
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/optimizer/builder.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import inspect
+
+import torch
+
+from ...utils import Registry, build_from_cfg
+
+OPTIMIZERS = Registry('optimizer')
+OPTIMIZER_BUILDERS = Registry('optimizer builder')
+
+
+def register_torch_optimizers():
+ torch_optimizers = []
+ for module_name in dir(torch.optim):
+ if module_name.startswith('__'):
+ continue
+ _optim = getattr(torch.optim, module_name)
+ if inspect.isclass(_optim) and issubclass(_optim,
+ torch.optim.Optimizer):
+ OPTIMIZERS.register_module()(_optim)
+ torch_optimizers.append(module_name)
+ return torch_optimizers
+
+
+TORCH_OPTIMIZERS = register_torch_optimizers()
+
+
+def build_optimizer_constructor(cfg):
+ return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
+
+
+def build_optimizer(model, cfg):
+ optimizer_cfg = copy.deepcopy(cfg)
+ constructor_type = optimizer_cfg.pop('constructor',
+ 'DefaultOptimizerConstructor')
+ paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
+ optim_constructor = build_optimizer_constructor(
+ dict(
+ type=constructor_type,
+ optimizer_cfg=optimizer_cfg,
+ paramwise_cfg=paramwise_cfg))
+ optimizer = optim_constructor(model)
+ return optimizer
diff --git a/src/custom_mmpkg/custom_mmcv/runner/optimizer/default_constructor.py b/src/custom_mmpkg/custom_mmcv/runner/optimizer/default_constructor.py
new file mode 100644
index 0000000000000000000000000000000000000000..5901955857ab2d650907a284312c0a989de7b9a7
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/optimizer/default_constructor.py
@@ -0,0 +1,249 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+from torch.nn import GroupNorm, LayerNorm
+
+from custom_mmpkg.custom_mmcv.utils import _BatchNorm, _InstanceNorm, build_from_cfg, is_list_of
+from custom_mmpkg.custom_mmcv.utils.ext_loader import check_ops_exist
+from .builder import OPTIMIZER_BUILDERS, OPTIMIZERS
+
+
+@OPTIMIZER_BUILDERS.register_module()
+class DefaultOptimizerConstructor:
+ """Default constructor for optimizers.
+
+ By default each parameter share the same optimizer settings, and we
+ provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
+ It is a dict and may contain the following fields:
+
+ - ``custom_keys`` (dict): Specified parameters-wise settings by keys. If
+ one of the keys in ``custom_keys`` is a substring of the name of one
+ parameter, then the setting of the parameter will be specified by
+ ``custom_keys[key]`` and other setting like ``bias_lr_mult`` etc. will
+ be ignored. It should be noted that the aforementioned ``key`` is the
+ longest key that is a substring of the name of the parameter. If there
+ are multiple matched keys with the same length, then the key with lower
+ alphabet order will be chosen.
+ ``custom_keys[key]`` should be a dict and may contain fields ``lr_mult``
+ and ``decay_mult``. See Example 2 below.
+ - ``bias_lr_mult`` (float): It will be multiplied to the learning
+ rate for all bias parameters (except for those in normalization
+ layers and offset layers of DCN).
+ - ``bias_decay_mult`` (float): It will be multiplied to the weight
+ decay for all bias parameters (except for those in
+ normalization layers, depthwise conv layers, offset layers of DCN).
+ - ``norm_decay_mult`` (float): It will be multiplied to the weight
+ decay for all weight and bias parameters of normalization
+ layers.
+ - ``dwconv_decay_mult`` (float): It will be multiplied to the weight
+ decay for all weight and bias parameters of depthwise conv
+ layers.
+ - ``dcn_offset_lr_mult`` (float): It will be multiplied to the learning
+ rate for parameters of offset layer in the deformable convs
+ of a model.
+ - ``bypass_duplicate`` (bool): If true, the duplicate parameters
+ would not be added into optimizer. Default: False.
+
+ Note:
+ 1. If the option ``dcn_offset_lr_mult`` is used, the constructor will
+ override the effect of ``bias_lr_mult`` in the bias of offset
+ layer. So be careful when using both ``bias_lr_mult`` and
+ ``dcn_offset_lr_mult``. If you wish to apply both of them to the
+ offset layer in deformable convs, set ``dcn_offset_lr_mult``
+ to the original ``dcn_offset_lr_mult`` * ``bias_lr_mult``.
+ 2. If the option ``dcn_offset_lr_mult`` is used, the constructor will
+ apply it to all the DCN layers in the model. So be careful when
+ the model contains multiple DCN layers in places other than
+ backbone.
+
+ Args:
+ model (:obj:`nn.Module`): The model with parameters to be optimized.
+ optimizer_cfg (dict): The config dict of the optimizer.
+ Positional fields are
+
+ - `type`: class name of the optimizer.
+
+ Optional fields are
+
+ - any arguments of the corresponding optimizer type, e.g.,
+ lr, weight_decay, momentum, etc.
+ paramwise_cfg (dict, optional): Parameter-wise options.
+
+ Example 1:
+ >>> model = torch.nn.modules.Conv1d(1, 1, 1)
+ >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
+ >>> weight_decay=0.0001)
+ >>> paramwise_cfg = dict(norm_decay_mult=0.)
+ >>> optim_builder = DefaultOptimizerConstructor(
+ >>> optimizer_cfg, paramwise_cfg)
+ >>> optimizer = optim_builder(model)
+
+ Example 2:
+ >>> # assume model have attribute model.backbone and model.cls_head
+ >>> optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.95)
+ >>> paramwise_cfg = dict(custom_keys={
+ '.backbone': dict(lr_mult=0.1, decay_mult=0.9)})
+ >>> optim_builder = DefaultOptimizerConstructor(
+ >>> optimizer_cfg, paramwise_cfg)
+ >>> optimizer = optim_builder(model)
+ >>> # Then the `lr` and `weight_decay` for model.backbone is
+ >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for
+ >>> # model.cls_head is (0.01, 0.95).
+ """
+
+ def __init__(self, optimizer_cfg, paramwise_cfg=None):
+ if not isinstance(optimizer_cfg, dict):
+ raise TypeError('optimizer_cfg should be a dict',
+ f'but got {type(optimizer_cfg)}')
+ self.optimizer_cfg = optimizer_cfg
+ self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg
+ self.base_lr = optimizer_cfg.get('lr', None)
+ self.base_wd = optimizer_cfg.get('weight_decay', None)
+ self._validate_cfg()
+
+ def _validate_cfg(self):
+ if not isinstance(self.paramwise_cfg, dict):
+ raise TypeError('paramwise_cfg should be None or a dict, '
+ f'but got {type(self.paramwise_cfg)}')
+
+ if 'custom_keys' in self.paramwise_cfg:
+ if not isinstance(self.paramwise_cfg['custom_keys'], dict):
+ raise TypeError(
+ 'If specified, custom_keys must be a dict, '
+ f'but got {type(self.paramwise_cfg["custom_keys"])}')
+ if self.base_wd is None:
+ for key in self.paramwise_cfg['custom_keys']:
+ if 'decay_mult' in self.paramwise_cfg['custom_keys'][key]:
+ raise ValueError('base_wd should not be None')
+
+ # get base lr and weight decay
+ # weight_decay must be explicitly specified if mult is specified
+ if ('bias_decay_mult' in self.paramwise_cfg
+ or 'norm_decay_mult' in self.paramwise_cfg
+ or 'dwconv_decay_mult' in self.paramwise_cfg):
+ if self.base_wd is None:
+ raise ValueError('base_wd should not be None')
+
+ def _is_in(self, param_group, param_group_list):
+ assert is_list_of(param_group_list, dict)
+ param = set(param_group['params'])
+ param_set = set()
+ for group in param_group_list:
+ param_set.update(set(group['params']))
+
+ return not param.isdisjoint(param_set)
+
+ def add_params(self, params, module, prefix='', is_dcn_module=None):
+ """Add all parameters of module to the params list.
+
+ The parameters of the given module will be added to the list of param
+ groups, with specific rules defined by paramwise_cfg.
+
+ Args:
+ params (list[dict]): A list of param groups, it will be modified
+ in place.
+ module (nn.Module): The module to be added.
+ prefix (str): The prefix of the module
+ is_dcn_module (int|float|None): If the current module is a
+ submodule of DCN, `is_dcn_module` will be passed to
+ control conv_offset layer's learning rate. Defaults to None.
+ """
+ # get param-wise options
+ custom_keys = self.paramwise_cfg.get('custom_keys', {})
+ # first sort with alphabet order and then sort with reversed len of str
+ sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
+
+ bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
+ bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
+ norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
+ dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', 1.)
+ bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)
+ dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', 1.)
+
+ # special rules for norm layers and depth-wise conv layers
+ is_norm = isinstance(module,
+ (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
+ is_dwconv = (
+ isinstance(module, torch.nn.Conv2d)
+ and module.in_channels == module.groups)
+
+ for name, param in module.named_parameters(recurse=False):
+ param_group = {'params': [param]}
+ if not param.requires_grad:
+ params.append(param_group)
+ continue
+ if bypass_duplicate and self._is_in(param_group, params):
+ warnings.warn(f'{prefix} is duplicate. It is skipped since '
+ f'bypass_duplicate={bypass_duplicate}')
+ continue
+ # if the parameter match one of the custom keys, ignore other rules
+ is_custom = False
+ for key in sorted_keys:
+ if key in f'{prefix}.{name}':
+ is_custom = True
+ lr_mult = custom_keys[key].get('lr_mult', 1.)
+ param_group['lr'] = self.base_lr * lr_mult
+ if self.base_wd is not None:
+ decay_mult = custom_keys[key].get('decay_mult', 1.)
+ param_group['weight_decay'] = self.base_wd * decay_mult
+ break
+
+ if not is_custom:
+ # bias_lr_mult affects all bias parameters
+ # except for norm.bias dcn.conv_offset.bias
+ if name == 'bias' and not (is_norm or is_dcn_module):
+ param_group['lr'] = self.base_lr * bias_lr_mult
+
+ if (prefix.find('conv_offset') != -1 and is_dcn_module
+ and isinstance(module, torch.nn.Conv2d)):
+ # deal with both dcn_offset's bias & weight
+ param_group['lr'] = self.base_lr * dcn_offset_lr_mult
+
+ # apply weight decay policies
+ if self.base_wd is not None:
+ # norm decay
+ if is_norm:
+ param_group[
+ 'weight_decay'] = self.base_wd * norm_decay_mult
+ # depth-wise conv
+ elif is_dwconv:
+ param_group[
+ 'weight_decay'] = self.base_wd * dwconv_decay_mult
+ # bias lr and decay
+ elif name == 'bias' and not is_dcn_module:
+ # TODO: current bias_decay_mult will have affect on DCN
+ param_group[
+ 'weight_decay'] = self.base_wd * bias_decay_mult
+ params.append(param_group)
+
+ if check_ops_exist():
+ from custom_mmpkg.custom_mmcv.ops import DeformConv2d, ModulatedDeformConv2d
+ is_dcn_module = isinstance(module,
+ (DeformConv2d, ModulatedDeformConv2d))
+ else:
+ is_dcn_module = False
+ for child_name, child_mod in module.named_children():
+ child_prefix = f'{prefix}.{child_name}' if prefix else child_name
+ self.add_params(
+ params,
+ child_mod,
+ prefix=child_prefix,
+ is_dcn_module=is_dcn_module)
+
+ def __call__(self, model):
+ if hasattr(model, 'module'):
+ model = model.module
+
+ optimizer_cfg = self.optimizer_cfg.copy()
+ # if no paramwise option is specified, just use the global setting
+ if not self.paramwise_cfg:
+ optimizer_cfg['params'] = model.parameters()
+ return build_from_cfg(optimizer_cfg, OPTIMIZERS)
+
+ # set param-wise lr and weight decay recursively
+ params = []
+ self.add_params(params, model)
+ optimizer_cfg['params'] = params
+
+ return build_from_cfg(optimizer_cfg, OPTIMIZERS)
diff --git a/src/custom_mmpkg/custom_mmcv/runner/priority.py b/src/custom_mmpkg/custom_mmcv/runner/priority.py
new file mode 100644
index 0000000000000000000000000000000000000000..64cc4e3a05f8d5b89ab6eb32461e6e80f1d62e67
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/priority.py
@@ -0,0 +1,60 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from enum import Enum
+
+
+class Priority(Enum):
+ """Hook priority levels.
+
+ +--------------+------------+
+ | Level | Value |
+ +==============+============+
+ | HIGHEST | 0 |
+ +--------------+------------+
+ | VERY_HIGH | 10 |
+ +--------------+------------+
+ | HIGH | 30 |
+ +--------------+------------+
+ | ABOVE_NORMAL | 40 |
+ +--------------+------------+
+ | NORMAL | 50 |
+ +--------------+------------+
+ | BELOW_NORMAL | 60 |
+ +--------------+------------+
+ | LOW | 70 |
+ +--------------+------------+
+ | VERY_LOW | 90 |
+ +--------------+------------+
+ | LOWEST | 100 |
+ +--------------+------------+
+ """
+
+ HIGHEST = 0
+ VERY_HIGH = 10
+ HIGH = 30
+ ABOVE_NORMAL = 40
+ NORMAL = 50
+ BELOW_NORMAL = 60
+ LOW = 70
+ VERY_LOW = 90
+ LOWEST = 100
+
+
+def get_priority(priority):
+ """Get priority value.
+
+ Args:
+ priority (int or str or :obj:`Priority`): Priority.
+
+ Returns:
+ int: The priority value.
+ """
+ if isinstance(priority, int):
+ if priority < 0 or priority > 100:
+ raise ValueError('priority must be between 0 and 100')
+ return priority
+ elif isinstance(priority, Priority):
+ return priority.value
+ elif isinstance(priority, str):
+ return Priority[priority.upper()].value
+ else:
+ raise TypeError('priority must be an integer or Priority enum value')
diff --git a/src/custom_mmpkg/custom_mmcv/runner/utils.py b/src/custom_mmpkg/custom_mmcv/runner/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..32fa4a7297f2cb10f7f2824470434aa34d8de0bb
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/runner/utils.py
@@ -0,0 +1,93 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import random
+import sys
+import time
+import warnings
+from getpass import getuser
+from socket import gethostname
+
+import numpy as np
+import torch
+
+import custom_mmpkg.custom_mmcv as mmcv
+
+
+def get_host_info():
+ """Get hostname and username.
+
+ Return empty string if exception raised, e.g. ``getpass.getuser()`` will
+ lead to error in docker container
+ """
+ host = ''
+ try:
+ host = f'{getuser()}@{gethostname()}'
+ except Exception as e:
+ warnings.warn(f'Host or user not found: {str(e)}')
+ finally:
+ return host
+
+
+def get_time_str():
+ return time.strftime('%Y%m%d_%H%M%S', time.localtime())
+
+
+def obj_from_dict(info, parent=None, default_args=None):
+ """Initialize an object from dict.
+
+ The dict must contain the key "type", which indicates the object type, it
+ can be either a string or type, such as "list" or ``list``. Remaining
+ fields are treated as the arguments for constructing the object.
+
+ Args:
+ info (dict): Object types and arguments.
+ parent (:class:`module`): Module which may containing expected object
+ classes.
+ default_args (dict, optional): Default arguments for initializing the
+ object.
+
+ Returns:
+ any type: Object built from the dict.
+ """
+ assert isinstance(info, dict) and 'type' in info
+ assert isinstance(default_args, dict) or default_args is None
+ args = info.copy()
+ obj_type = args.pop('type')
+ if mmcv.is_str(obj_type):
+ if parent is not None:
+ obj_type = getattr(parent, obj_type)
+ else:
+ obj_type = sys.modules[obj_type]
+ elif not isinstance(obj_type, type):
+ raise TypeError('type must be a str or valid type, but '
+ f'got {type(obj_type)}')
+ if default_args is not None:
+ for name, value in default_args.items():
+ args.setdefault(name, value)
+ return obj_type(**args)
+
+
+def set_random_seed(seed, deterministic=False, use_rank_shift=False):
+ """Set random seed.
+
+ Args:
+ seed (int): Seed to be used.
+ deterministic (bool): Whether to set the deterministic option for
+ CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+ to True and `torch.backends.cudnn.benchmark` to False.
+ Default: False.
+ rank_shift (bool): Whether to add rank number to the random seed to
+ have different random seed in different threads. Default: False.
+ """
+ if use_rank_shift:
+ rank, _ = mmcv.runner.get_dist_info()
+ seed += rank
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ os.environ['PYTHONHASHSEED'] = str(seed)
+ if deterministic:
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
diff --git a/src/custom_mmpkg/custom_mmcv/utils/__init__.py b/src/custom_mmpkg/custom_mmcv/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..378a0068432a371af364de9d73785901c0f83383
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/utils/__init__.py
@@ -0,0 +1,69 @@
+# flake8: noqa
+# Copyright (c) OpenMMLab. All rights reserved.
+from .config import Config, ConfigDict, DictAction
+from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
+ has_method, import_modules_from_strings, is_list_of,
+ is_method_overridden, is_seq_of, is_str, is_tuple_of,
+ iter_cast, list_cast, requires_executable, requires_package,
+ slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple,
+ to_ntuple, tuple_cast)
+from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
+ scandir, symlink)
+from .progressbar import (ProgressBar, track_iter_progress,
+ track_parallel_progress, track_progress)
+from .testing import (assert_attrs_equal, assert_dict_contains_subset,
+ assert_dict_has_keys, assert_is_norm_layer,
+ assert_keys_equal, assert_params_all_zeros,
+ check_python_script)
+from .timer import Timer, TimerError, check_time
+from .version_utils import digit_version, get_git_hash
+
+try:
+ import torch
+except ImportError:
+ __all__ = [
+ 'Config', 'ConfigDict', 'DictAction', 'is_str', 'iter_cast',
+ 'list_cast', 'tuple_cast', 'is_seq_of', 'is_list_of', 'is_tuple_of',
+ 'slice_list', 'concat_list', 'check_prerequisites', 'requires_package',
+ 'requires_executable', 'is_filepath', 'fopen', 'check_file_exist',
+ 'mkdir_or_exist', 'symlink', 'scandir', 'ProgressBar',
+ 'track_progress', 'track_iter_progress', 'track_parallel_progress',
+ 'Timer', 'TimerError', 'check_time', 'deprecated_api_warning',
+ 'digit_version', 'get_git_hash', 'import_modules_from_strings',
+ 'assert_dict_contains_subset', 'assert_attrs_equal',
+ 'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script',
+ 'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
+ 'is_method_overridden', 'has_method'
+ ]
+else:
+ from .env import collect_env
+ from .logging import get_logger, print_log
+ from .parrots_jit import jit, skip_no_elena
+ from .parrots_wrapper import (
+ TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension, DataLoader,
+ PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd,
+ _AvgPoolNd, _BatchNorm, _ConvNd, _ConvTransposeMixin, _InstanceNorm,
+ _MaxPoolNd, get_build_config, is_rocm_pytorch, _get_cuda_home)
+ from .registry import Registry, build_from_cfg
+ from .trace import is_jit_tracing
+ __all__ = [
+ 'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
+ 'print_log', 'is_str', 'iter_cast', 'list_cast', 'tuple_cast',
+ 'is_seq_of', 'is_list_of', 'is_tuple_of', 'slice_list', 'concat_list',
+ 'check_prerequisites', 'requires_package', 'requires_executable',
+ 'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist',
+ 'symlink', 'scandir', 'ProgressBar', 'track_progress',
+ 'track_iter_progress', 'track_parallel_progress', 'Registry',
+ 'build_from_cfg', 'Timer', 'TimerError', 'check_time', 'SyncBatchNorm',
+ '_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd', '_AvgPoolNd', '_BatchNorm',
+ '_ConvNd', '_ConvTransposeMixin', '_InstanceNorm', '_MaxPoolNd',
+ 'get_build_config', 'BuildExtension', 'CppExtension', 'CUDAExtension',
+ 'DataLoader', 'PoolDataLoader', 'TORCH_VERSION',
+ 'deprecated_api_warning', 'digit_version', 'get_git_hash',
+ 'import_modules_from_strings', 'jit', 'skip_no_elena',
+ 'assert_dict_contains_subset', 'assert_attrs_equal',
+ 'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
+ 'assert_params_all_zeros', 'check_python_script',
+ 'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
+ '_get_cuda_home', 'has_method'
+ ]
diff --git a/src/custom_mmpkg/custom_mmcv/utils/config.py b/src/custom_mmpkg/custom_mmcv/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..098a706764a1c18fee26bdaae6d5898d9af23282
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/utils/config.py
@@ -0,0 +1,688 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import ast
+import copy
+import os
+import os.path as osp
+import platform
+import shutil
+import sys
+import tempfile
+import uuid
+import warnings
+from argparse import Action, ArgumentParser
+from collections import abc
+from importlib import import_module
+
+from addict import Dict
+from yapf.yapflib.yapf_api import FormatCode
+
+from .misc import import_modules_from_strings
+from .path import check_file_exist
+
+if platform.system() == 'Windows':
+ import regex as re
+else:
+ import re
+
+BASE_KEY = '_base_'
+DELETE_KEY = '_delete_'
+DEPRECATION_KEY = '_deprecation_'
+RESERVED_KEYS = ['filename', 'text', 'pretty_text']
+
+
+class ConfigDict(Dict):
+
+ def __missing__(self, name):
+ raise KeyError(name)
+
+ def __getattr__(self, name):
+ try:
+ value = super(ConfigDict, self).__getattr__(name)
+ except KeyError:
+ ex = AttributeError(f"'{self.__class__.__name__}' object has no "
+ f"attribute '{name}'")
+ except Exception as e:
+ ex = e
+ else:
+ return value
+ raise ex
+
+
+def add_args(parser, cfg, prefix=''):
+ for k, v in cfg.items():
+ if isinstance(v, str):
+ parser.add_argument('--' + prefix + k)
+ elif isinstance(v, int):
+ parser.add_argument('--' + prefix + k, type=int)
+ elif isinstance(v, float):
+ parser.add_argument('--' + prefix + k, type=float)
+ elif isinstance(v, bool):
+ parser.add_argument('--' + prefix + k, action='store_true')
+ elif isinstance(v, dict):
+ add_args(parser, v, prefix + k + '.')
+ elif isinstance(v, abc.Iterable):
+ parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+')
+ else:
+ print(f'cannot parse key {prefix + k} of type {type(v)}')
+ return parser
+
+
+class Config:
+ """A facility for config and config files.
+
+ It supports common file formats as configs: python/json/yaml. The interface
+ is the same as a dict object and also allows access config values as
+ attributes.
+
+ Example:
+ >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
+ >>> cfg.a
+ 1
+ >>> cfg.b
+ {'b1': [0, 1]}
+ >>> cfg.b.b1
+ [0, 1]
+ >>> cfg = Config.fromfile('tests/data/config/a.py')
+ >>> cfg.filename
+ "/home/kchen/projects/mmcv/tests/data/config/a.py"
+ >>> cfg.item4
+ 'test'
+ >>> cfg
+ "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
+ "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
+ """
+
+ @staticmethod
+ def _validate_py_syntax(filename):
+ with open(filename, 'r', encoding='utf-8') as f:
+ # Setting encoding explicitly to resolve coding issue on windows
+ content = f.read()
+ try:
+ ast.parse(content)
+ except SyntaxError as e:
+ raise SyntaxError('There are syntax errors in config '
+ f'file {filename}: {e}')
+
+ @staticmethod
+ def _substitute_predefined_vars(filename, temp_config_name):
+ file_dirname = osp.dirname(filename)
+ file_basename = osp.basename(filename)
+ file_basename_no_extension = osp.splitext(file_basename)[0]
+ file_extname = osp.splitext(filename)[1]
+ support_templates = dict(
+ fileDirname=file_dirname,
+ fileBasename=file_basename,
+ fileBasenameNoExtension=file_basename_no_extension,
+ fileExtname=file_extname)
+ with open(filename, 'r', encoding='utf-8') as f:
+ # Setting encoding explicitly to resolve coding issue on windows
+ config_file = f.read()
+ for key, value in support_templates.items():
+ regexp = r'\{\{\s*' + str(key) + r'\s*\}\}'
+ value = value.replace('\\', '/')
+ config_file = re.sub(regexp, value, config_file)
+ with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file:
+ tmp_config_file.write(config_file)
+
+ @staticmethod
+ def _pre_substitute_base_vars(filename, temp_config_name):
+ """Substitute base variable placehoders to string, so that parsing
+ would work."""
+ with open(filename, 'r', encoding='utf-8') as f:
+ # Setting encoding explicitly to resolve coding issue on windows
+ config_file = f.read()
+ base_var_dict = {}
+ regexp = r'\{\{\s*' + BASE_KEY + r'\.([\w\.]+)\s*\}\}'
+ base_vars = set(re.findall(regexp, config_file))
+ for base_var in base_vars:
+ randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}'
+ base_var_dict[randstr] = base_var
+ regexp = r'\{\{\s*' + BASE_KEY + r'\.' + base_var + r'\s*\}\}'
+ config_file = re.sub(regexp, f'"{randstr}"', config_file)
+ with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file:
+ tmp_config_file.write(config_file)
+ return base_var_dict
+
+ @staticmethod
+ def _substitute_base_vars(cfg, base_var_dict, base_cfg):
+ """Substitute variable strings to their actual values."""
+ cfg = copy.deepcopy(cfg)
+
+ if isinstance(cfg, dict):
+ for k, v in cfg.items():
+ if isinstance(v, str) and v in base_var_dict:
+ new_v = base_cfg
+ for new_k in base_var_dict[v].split('.'):
+ new_v = new_v[new_k]
+ cfg[k] = new_v
+ elif isinstance(v, (list, tuple, dict)):
+ cfg[k] = Config._substitute_base_vars(
+ v, base_var_dict, base_cfg)
+ elif isinstance(cfg, tuple):
+ cfg = tuple(
+ Config._substitute_base_vars(c, base_var_dict, base_cfg)
+ for c in cfg)
+ elif isinstance(cfg, list):
+ cfg = [
+ Config._substitute_base_vars(c, base_var_dict, base_cfg)
+ for c in cfg
+ ]
+ elif isinstance(cfg, str) and cfg in base_var_dict:
+ new_v = base_cfg
+ for new_k in base_var_dict[cfg].split('.'):
+ new_v = new_v[new_k]
+ cfg = new_v
+
+ return cfg
+
+ @staticmethod
+ def _file2dict(filename, use_predefined_variables=True):
+ filename = osp.abspath(osp.expanduser(filename))
+ check_file_exist(filename)
+ fileExtname = osp.splitext(filename)[1]
+ if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
+ raise IOError('Only py/yml/yaml/json type are supported now!')
+
+ with tempfile.TemporaryDirectory() as temp_config_dir:
+ temp_config_file = tempfile.NamedTemporaryFile(
+ dir=temp_config_dir, suffix=fileExtname)
+ if platform.system() == 'Windows':
+ temp_config_file.close()
+ temp_config_name = osp.basename(temp_config_file.name)
+ # Substitute predefined variables
+ if use_predefined_variables:
+ Config._substitute_predefined_vars(filename,
+ temp_config_file.name)
+ else:
+ shutil.copyfile(filename, temp_config_file.name)
+ # Substitute base variables from placeholders to strings
+ base_var_dict = Config._pre_substitute_base_vars(
+ temp_config_file.name, temp_config_file.name)
+
+ if filename.endswith('.py'):
+ temp_module_name = osp.splitext(temp_config_name)[0]
+ sys.path.insert(0, temp_config_dir)
+ Config._validate_py_syntax(filename)
+ mod = import_module(temp_module_name)
+ sys.path.pop(0)
+ cfg_dict = {
+ name: value
+ for name, value in mod.__dict__.items()
+ if not name.startswith('__')
+ }
+ # delete imported module
+ del sys.modules[temp_module_name]
+ elif filename.endswith(('.yml', '.yaml', '.json')):
+ import custom_mmpkg.custom_mmcv as mmcv
+ cfg_dict = mmcv.load(temp_config_file.name)
+ # close temp file
+ temp_config_file.close()
+
+ # check deprecation information
+ if DEPRECATION_KEY in cfg_dict:
+ deprecation_info = cfg_dict.pop(DEPRECATION_KEY)
+ warning_msg = f'The config file {filename} will be deprecated ' \
+ 'in the future.'
+ if 'expected' in deprecation_info:
+ warning_msg += f' Please use {deprecation_info["expected"]} ' \
+ 'instead.'
+ if 'reference' in deprecation_info:
+ warning_msg += ' More information can be found at ' \
+ f'{deprecation_info["reference"]}'
+ warnings.warn(warning_msg)
+
+ cfg_text = filename + '\n'
+ with open(filename, 'r', encoding='utf-8') as f:
+ # Setting encoding explicitly to resolve coding issue on windows
+ cfg_text += f.read()
+
+ if BASE_KEY in cfg_dict:
+ cfg_dir = osp.dirname(filename)
+ base_filename = cfg_dict.pop(BASE_KEY)
+ base_filename = base_filename if isinstance(
+ base_filename, list) else [base_filename]
+
+ cfg_dict_list = list()
+ cfg_text_list = list()
+ for f in base_filename:
+ _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
+ cfg_dict_list.append(_cfg_dict)
+ cfg_text_list.append(_cfg_text)
+
+ base_cfg_dict = dict()
+ for c in cfg_dict_list:
+ duplicate_keys = base_cfg_dict.keys() & c.keys()
+ if len(duplicate_keys) > 0:
+ raise KeyError('Duplicate key is not allowed among bases. '
+ f'Duplicate keys: {duplicate_keys}')
+ base_cfg_dict.update(c)
+
+ # Substitute base variables from strings to their actual values
+ cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict,
+ base_cfg_dict)
+
+ base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
+ cfg_dict = base_cfg_dict
+
+ # merge cfg_text
+ cfg_text_list.append(cfg_text)
+ cfg_text = '\n'.join(cfg_text_list)
+
+ return cfg_dict, cfg_text
+
+ @staticmethod
+ def _merge_a_into_b(a, b, allow_list_keys=False):
+ """merge dict ``a`` into dict ``b`` (non-inplace).
+
+ Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
+ in-place modifications.
+
+ Args:
+ a (dict): The source dict to be merged into ``b``.
+ b (dict): The origin dict to be fetch keys from ``a``.
+ allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
+ are allowed in source ``a`` and will replace the element of the
+ corresponding index in b if b is a list. Default: False.
+
+ Returns:
+ dict: The modified dict of ``b`` using ``a``.
+
+ Examples:
+ # Normally merge a into b.
+ >>> Config._merge_a_into_b(
+ ... dict(obj=dict(a=2)), dict(obj=dict(a=1)))
+ {'obj': {'a': 2}}
+
+ # Delete b first and merge a into b.
+ >>> Config._merge_a_into_b(
+ ... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1)))
+ {'obj': {'a': 2}}
+
+ # b is a list
+ >>> Config._merge_a_into_b(
+ ... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True)
+ [{'a': 2}, {'b': 2}]
+ """
+ b = b.copy()
+ for k, v in a.items():
+ if allow_list_keys and k.isdigit() and isinstance(b, list):
+ k = int(k)
+ if len(b) <= k:
+ raise KeyError(f'Index {k} exceeds the length of list {b}')
+ b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
+ elif isinstance(v,
+ dict) and k in b and not v.pop(DELETE_KEY, False):
+ allowed_types = (dict, list) if allow_list_keys else dict
+ if not isinstance(b[k], allowed_types):
+ raise TypeError(
+ f'{k}={v} in child config cannot inherit from base '
+ f'because {k} is a dict in the child config but is of '
+ f'type {type(b[k])} in base config. You may set '
+ f'`{DELETE_KEY}=True` to ignore the base config')
+ b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
+ else:
+ b[k] = v
+ return b
+
+ @staticmethod
+ def fromfile(filename,
+ use_predefined_variables=True,
+ import_custom_modules=True):
+ cfg_dict, cfg_text = Config._file2dict(filename,
+ use_predefined_variables)
+ if import_custom_modules and cfg_dict.get('custom_imports', None):
+ import_modules_from_strings(**cfg_dict['custom_imports'])
+ return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
+
+ @staticmethod
+ def fromstring(cfg_str, file_format):
+ """Generate config from config str.
+
+ Args:
+ cfg_str (str): Config str.
+ file_format (str): Config file format corresponding to the
+ config str. Only py/yml/yaml/json type are supported now!
+
+ Returns:
+ obj:`Config`: Config obj.
+ """
+ if file_format not in ['.py', '.json', '.yaml', '.yml']:
+ raise IOError('Only py/yml/yaml/json type are supported now!')
+ if file_format != '.py' and 'dict(' in cfg_str:
+ # check if users specify a wrong suffix for python
+ warnings.warn(
+ 'Please check "file_format", the file format may be .py')
+ with tempfile.NamedTemporaryFile(
+ 'w', encoding='utf-8', suffix=file_format,
+ delete=False) as temp_file:
+ temp_file.write(cfg_str)
+ # on windows, previous implementation cause error
+ # see PR 1077 for details
+ cfg = Config.fromfile(temp_file.name)
+ os.remove(temp_file.name)
+ return cfg
+
+ @staticmethod
+ def auto_argparser(description=None):
+ """Generate argparser from config file automatically (experimental)"""
+ partial_parser = ArgumentParser(description=description)
+ partial_parser.add_argument('config', help='config file path')
+ cfg_file = partial_parser.parse_known_args()[0].config
+ cfg = Config.fromfile(cfg_file)
+ parser = ArgumentParser(description=description)
+ parser.add_argument('config', help='config file path')
+ add_args(parser, cfg)
+ return parser, cfg
+
+ def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
+ if cfg_dict is None:
+ cfg_dict = dict()
+ elif not isinstance(cfg_dict, dict):
+ raise TypeError('cfg_dict must be a dict, but '
+ f'got {type(cfg_dict)}')
+ for key in cfg_dict:
+ if key in RESERVED_KEYS:
+ raise KeyError(f'{key} is reserved for config file')
+
+ super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
+ super(Config, self).__setattr__('_filename', filename)
+ if cfg_text:
+ text = cfg_text
+ elif filename:
+ with open(filename, 'r') as f:
+ text = f.read()
+ else:
+ text = ''
+ super(Config, self).__setattr__('_text', text)
+
+ @property
+ def filename(self):
+ return self._filename
+
+ @property
+ def text(self):
+ return self._text
+
+ @property
+ def pretty_text(self):
+
+ indent = 4
+
+ def _indent(s_, num_spaces):
+ s = s_.split('\n')
+ if len(s) == 1:
+ return s_
+ first = s.pop(0)
+ s = [(num_spaces * ' ') + line for line in s]
+ s = '\n'.join(s)
+ s = first + '\n' + s
+ return s
+
+ def _format_basic_types(k, v, use_mapping=False):
+ if isinstance(v, str):
+ v_str = f"'{v}'"
+ else:
+ v_str = str(v)
+
+ if use_mapping:
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
+ attr_str = f'{k_str}: {v_str}'
+ else:
+ attr_str = f'{str(k)}={v_str}'
+ attr_str = _indent(attr_str, indent)
+
+ return attr_str
+
+ def _format_list(k, v, use_mapping=False):
+ # check if all items in the list are dict
+ if all(isinstance(_, dict) for _ in v):
+ v_str = '[\n'
+ v_str += '\n'.join(
+ f'dict({_indent(_format_dict(v_), indent)}),'
+ for v_ in v).rstrip(',')
+ if use_mapping:
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
+ attr_str = f'{k_str}: {v_str}'
+ else:
+ attr_str = f'{str(k)}={v_str}'
+ attr_str = _indent(attr_str, indent) + ']'
+ else:
+ attr_str = _format_basic_types(k, v, use_mapping)
+ return attr_str
+
+ def _contain_invalid_identifier(dict_str):
+ contain_invalid_identifier = False
+ for key_name in dict_str:
+ contain_invalid_identifier |= \
+ (not str(key_name).isidentifier())
+ return contain_invalid_identifier
+
+ def _format_dict(input_dict, outest_level=False):
+ r = ''
+ s = []
+
+ use_mapping = _contain_invalid_identifier(input_dict)
+ if use_mapping:
+ r += '{'
+ for idx, (k, v) in enumerate(input_dict.items()):
+ is_last = idx >= len(input_dict) - 1
+ end = '' if outest_level or is_last else ','
+ if isinstance(v, dict):
+ v_str = '\n' + _format_dict(v)
+ if use_mapping:
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
+ attr_str = f'{k_str}: dict({v_str}'
+ else:
+ attr_str = f'{str(k)}=dict({v_str}'
+ attr_str = _indent(attr_str, indent) + ')' + end
+ elif isinstance(v, list):
+ attr_str = _format_list(k, v, use_mapping) + end
+ else:
+ attr_str = _format_basic_types(k, v, use_mapping) + end
+
+ s.append(attr_str)
+ r += '\n'.join(s)
+ if use_mapping:
+ r += '}'
+ return r
+
+ cfg_dict = self._cfg_dict.to_dict()
+ text = _format_dict(cfg_dict, outest_level=True)
+ # copied from setup.cfg
+ yapf_style = dict(
+ based_on_style='pep8',
+ blank_line_before_nested_class_or_def=True,
+ split_before_expression_after_opening_paren=True)
+ text, _ = FormatCode(text, style_config=yapf_style, verify=True)
+
+ return text
+
+ def __repr__(self):
+ return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'
+
+ def __len__(self):
+ return len(self._cfg_dict)
+
+ def __getattr__(self, name):
+ return getattr(self._cfg_dict, name)
+
+ def __getitem__(self, name):
+ return self._cfg_dict.__getitem__(name)
+
+ def __setattr__(self, name, value):
+ if isinstance(value, dict):
+ value = ConfigDict(value)
+ self._cfg_dict.__setattr__(name, value)
+
+ def __setitem__(self, name, value):
+ if isinstance(value, dict):
+ value = ConfigDict(value)
+ self._cfg_dict.__setitem__(name, value)
+
+ def __iter__(self):
+ return iter(self._cfg_dict)
+
+ def __getstate__(self):
+ return (self._cfg_dict, self._filename, self._text)
+
+ def __setstate__(self, state):
+ _cfg_dict, _filename, _text = state
+ super(Config, self).__setattr__('_cfg_dict', _cfg_dict)
+ super(Config, self).__setattr__('_filename', _filename)
+ super(Config, self).__setattr__('_text', _text)
+
+ def dump(self, file=None):
+ cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
+ if self.filename.endswith('.py'):
+ if file is None:
+ return self.pretty_text
+ else:
+ with open(file, 'w', encoding='utf-8') as f:
+ f.write(self.pretty_text)
+ else:
+ import custom_mmpkg.custom_mmcv as mmcv
+ if file is None:
+ file_format = self.filename.split('.')[-1]
+ return mmcv.dump(cfg_dict, file_format=file_format)
+ else:
+ mmcv.dump(cfg_dict, file)
+
+ def merge_from_dict(self, options, allow_list_keys=True):
+ """Merge list into cfg_dict.
+
+ Merge the dict parsed by MultipleKVAction into this cfg.
+
+ Examples:
+ >>> options = {'model.backbone.depth': 50,
+ ... 'model.backbone.with_cp':True}
+ >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
+ >>> cfg.merge_from_dict(options)
+ >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
+ >>> assert cfg_dict == dict(
+ ... model=dict(backbone=dict(depth=50, with_cp=True)))
+
+ # Merge list element
+ >>> cfg = Config(dict(pipeline=[
+ ... dict(type='LoadImage'), dict(type='LoadAnnotations')]))
+ >>> options = dict(pipeline={'0': dict(type='SelfLoadImage')})
+ >>> cfg.merge_from_dict(options, allow_list_keys=True)
+ >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
+ >>> assert cfg_dict == dict(pipeline=[
+ ... dict(type='SelfLoadImage'), dict(type='LoadAnnotations')])
+
+ Args:
+ options (dict): dict of configs to merge from.
+ allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
+ are allowed in ``options`` and will replace the element of the
+ corresponding index in the config if the config is a list.
+ Default: True.
+ """
+ option_cfg_dict = {}
+ for full_key, v in options.items():
+ d = option_cfg_dict
+ key_list = full_key.split('.')
+ for subkey in key_list[:-1]:
+ d.setdefault(subkey, ConfigDict())
+ d = d[subkey]
+ subkey = key_list[-1]
+ d[subkey] = v
+
+ cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
+ super(Config, self).__setattr__(
+ '_cfg_dict',
+ Config._merge_a_into_b(
+ option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys))
+
+
+class DictAction(Action):
+ """
+ argparse action to split an argument into KEY=VALUE form
+ on the first = and append to a dictionary. List options can
+ be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit
+ brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build
+ list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
+ """
+
+ @staticmethod
+ def _parse_int_float_bool(val):
+ try:
+ return int(val)
+ except ValueError:
+ pass
+ try:
+ return float(val)
+ except ValueError:
+ pass
+ if val.lower() in ['true', 'false']:
+ return True if val.lower() == 'true' else False
+ return val
+
+ @staticmethod
+ def _parse_iterable(val):
+ """Parse iterable values in the string.
+
+ All elements inside '()' or '[]' are treated as iterable values.
+
+ Args:
+ val (str): Value string.
+
+ Returns:
+ list | tuple: The expanded list or tuple from the string.
+
+ Examples:
+ >>> DictAction._parse_iterable('1,2,3')
+ [1, 2, 3]
+ >>> DictAction._parse_iterable('[a, b, c]')
+ ['a', 'b', 'c']
+ >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]')
+ [(1, 2, 3), ['a', 'b'], 'c']
+ """
+
+ def find_next_comma(string):
+ """Find the position of next comma in the string.
+
+ If no ',' is found in the string, return the string length. All
+ chars inside '()' and '[]' are treated as one element and thus ','
+ inside these brackets are ignored.
+ """
+ assert (string.count('(') == string.count(')')) and (
+ string.count('[') == string.count(']')), \
+ f'Imbalanced brackets exist in {string}'
+ end = len(string)
+ for idx, char in enumerate(string):
+ pre = string[:idx]
+ # The string before this ',' is balanced
+ if ((char == ',') and (pre.count('(') == pre.count(')'))
+ and (pre.count('[') == pre.count(']'))):
+ end = idx
+ break
+ return end
+
+ # Strip ' and " characters and replace whitespace.
+ val = val.strip('\'\"').replace(' ', '')
+ is_tuple = False
+ if val.startswith('(') and val.endswith(')'):
+ is_tuple = True
+ val = val[1:-1]
+ elif val.startswith('[') and val.endswith(']'):
+ val = val[1:-1]
+ elif ',' not in val:
+ # val is a single value
+ return DictAction._parse_int_float_bool(val)
+
+ values = []
+ while len(val) > 0:
+ comma_idx = find_next_comma(val)
+ element = DictAction._parse_iterable(val[:comma_idx])
+ values.append(element)
+ val = val[comma_idx + 1:]
+ if is_tuple:
+ values = tuple(values)
+ return values
+
+ def __call__(self, parser, namespace, values, option_string=None):
+ options = {}
+ for kv in values:
+ key, val = kv.split('=', maxsplit=1)
+ options[key] = self._parse_iterable(val)
+ setattr(namespace, self.dest, options)
diff --git a/src/custom_mmpkg/custom_mmcv/utils/env.py b/src/custom_mmpkg/custom_mmcv/utils/env.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffc2e44d2d272d81c74fb2333849265011cd5fec
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/utils/env.py
@@ -0,0 +1,95 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""This file holding some environment constant for sharing by other files."""
+
+import os.path as osp
+import subprocess
+import sys
+from collections import defaultdict
+
+import cv2
+import torch
+
+import custom_mmpkg.custom_mmcv as mmcv
+from .parrots_wrapper import get_build_config
+
+
+def collect_env():
+ """Collect the information of the running environments.
+
+ Returns:
+ dict: The environment information. The following fields are contained.
+
+ - sys.platform: The variable of ``sys.platform``.
+ - Python: Python version.
+ - CUDA available: Bool, indicating if CUDA is available.
+ - GPU devices: Device type of each GPU.
+ - CUDA_HOME (optional): The env var ``CUDA_HOME``.
+ - NVCC (optional): NVCC version.
+ - GCC: GCC version, "n/a" if GCC is not installed.
+ - PyTorch: PyTorch version.
+ - PyTorch compiling details: The output of \
+ ``torch.__config__.show()``.
+ - TorchVision (optional): TorchVision version.
+ - OpenCV: OpenCV version.
+ - MMCV: MMCV version.
+ - MMCV Compiler: The GCC version for compiling MMCV ops.
+ - MMCV CUDA Compiler: The CUDA version for compiling MMCV ops.
+ """
+ env_info = {}
+ env_info['sys.platform'] = sys.platform
+ env_info['Python'] = sys.version.replace('\n', '')
+
+ cuda_available = torch.cuda.is_available()
+ env_info['CUDA available'] = cuda_available
+
+ if cuda_available:
+ devices = defaultdict(list)
+ for k in range(torch.cuda.device_count()):
+ devices[torch.cuda.get_device_name(k)].append(str(k))
+ for name, device_ids in devices.items():
+ env_info['GPU ' + ','.join(device_ids)] = name
+
+ from custom_mmpkg.custom_mmcv.utils.parrots_wrapper import _get_cuda_home
+ CUDA_HOME = _get_cuda_home()
+ env_info['CUDA_HOME'] = CUDA_HOME
+
+ if CUDA_HOME is not None and osp.isdir(CUDA_HOME):
+ try:
+ nvcc = osp.join(CUDA_HOME, 'bin/nvcc')
+ nvcc = subprocess.check_output(
+ f'"{nvcc}" -V | tail -n1', shell=True)
+ nvcc = nvcc.decode('utf-8').strip()
+ except subprocess.SubprocessError:
+ nvcc = 'Not Available'
+ env_info['NVCC'] = nvcc
+
+ try:
+ gcc = subprocess.check_output('gcc --version | head -n1', shell=True)
+ gcc = gcc.decode('utf-8').strip()
+ env_info['GCC'] = gcc
+ except subprocess.CalledProcessError: # gcc is unavailable
+ env_info['GCC'] = 'n/a'
+
+ env_info['PyTorch'] = torch.__version__
+ env_info['PyTorch compiling details'] = get_build_config()
+
+ try:
+ import torchvision
+ env_info['TorchVision'] = torchvision.__version__
+ except ModuleNotFoundError:
+ pass
+
+ env_info['OpenCV'] = cv2.__version__
+
+ env_info['MMCV'] = mmcv.__version__
+
+ try:
+ from custom_mmpkg.custom_mmcv.ops import get_compiler_version, get_compiling_cuda_version
+ except ModuleNotFoundError:
+ env_info['MMCV Compiler'] = 'n/a'
+ env_info['MMCV CUDA Compiler'] = 'n/a'
+ else:
+ env_info['MMCV Compiler'] = get_compiler_version()
+ env_info['MMCV CUDA Compiler'] = get_compiling_cuda_version()
+
+ return env_info
diff --git a/src/custom_mmpkg/custom_mmcv/utils/ext_loader.py b/src/custom_mmpkg/custom_mmcv/utils/ext_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fbdddf85818f8c6f2fb8b121c9fdc26259a64b8
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/utils/ext_loader.py
@@ -0,0 +1,71 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import importlib
+import os
+import pkgutil
+import warnings
+from collections import namedtuple
+
+import torch
+
+if torch.__version__ != 'parrots':
+
+ def load_ext(name, funcs):
+ ext = importlib.import_module('custom_mmcv.' + name)
+ for fun in funcs:
+ assert hasattr(ext, fun), f'{fun} miss in module {name}'
+ return ext
+else:
+ from parrots import extension
+ from parrots.base import ParrotsException
+
+ has_return_value_ops = [
+ 'nms',
+ 'softnms',
+ 'nms_match',
+ 'nms_rotated',
+ 'top_pool_forward',
+ 'top_pool_backward',
+ 'bottom_pool_forward',
+ 'bottom_pool_backward',
+ 'left_pool_forward',
+ 'left_pool_backward',
+ 'right_pool_forward',
+ 'right_pool_backward',
+ 'fused_bias_leakyrelu',
+ 'upfirdn2d',
+ 'ms_deform_attn_forward',
+ 'pixel_group',
+ 'contour_expand',
+ ]
+
+ def get_fake_func(name, e):
+
+ def fake_func(*args, **kwargs):
+ warnings.warn(f'{name} is not supported in parrots now')
+ raise e
+
+ return fake_func
+
+ def load_ext(name, funcs):
+ ExtModule = namedtuple('ExtModule', funcs)
+ ext_list = []
+ lib_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+ for fun in funcs:
+ try:
+ ext_fun = extension.load(fun, name, lib_dir=lib_root)
+ except ParrotsException as e:
+ if 'No element registered' not in e.message:
+ warnings.warn(e.message)
+ ext_fun = get_fake_func(fun, e)
+ ext_list.append(ext_fun)
+ else:
+ if fun in has_return_value_ops:
+ ext_list.append(ext_fun.op)
+ else:
+ ext_list.append(ext_fun.op_)
+ return ExtModule(*ext_list)
+
+
+def check_ops_exist():
+ ext_loader = pkgutil.find_loader('mmcv._ext')
+ return ext_loader is not None
diff --git a/src/custom_mmpkg/custom_mmcv/utils/logging.py b/src/custom_mmpkg/custom_mmcv/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..4aa0e04bb9b3ab2a4bfbc4def50404ccbac2c6e6
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/utils/logging.py
@@ -0,0 +1,110 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+
+import torch.distributed as dist
+
+logger_initialized = {}
+
+
+def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
+ """Initialize and get a logger by name.
+
+ If the logger has not been initialized, this method will initialize the
+ logger by adding one or two handlers, otherwise the initialized logger will
+ be directly returned. During initialization, a StreamHandler will always be
+ added. If `log_file` is specified and the process rank is 0, a FileHandler
+ will also be added.
+
+ Args:
+ name (str): Logger name.
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the logger.
+ log_level (int): The logger level. Note that only the process of
+ rank 0 is affected, and other processes will set the level to
+ "Error" thus be silent most of the time.
+ file_mode (str): The file mode used in opening log file.
+ Defaults to 'w'.
+
+ Returns:
+ logging.Logger: The expected logger.
+ """
+ logger = logging.getLogger(name)
+ if name in logger_initialized:
+ return logger
+ # handle hierarchical names
+ # e.g., logger "a" is initialized, then logger "a.b" will skip the
+ # initialization since it is a child of "a".
+ for logger_name in logger_initialized:
+ if name.startswith(logger_name):
+ return logger
+
+ # handle duplicate logs to the console
+ # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET)
+ # to the root logger. As logger.propagate is True by default, this root
+ # level handler causes logging messages from rank>0 processes to
+ # unexpectedly show up on the console, creating much unwanted clutter.
+ # To fix this issue, we set the root logger's StreamHandler, if any, to log
+ # at the ERROR level.
+ for handler in logger.root.handlers:
+ if type(handler) is logging.StreamHandler:
+ handler.setLevel(logging.ERROR)
+
+ stream_handler = logging.StreamHandler()
+ handlers = [stream_handler]
+
+ if dist.is_available() and dist.is_initialized():
+ rank = dist.get_rank()
+ else:
+ rank = 0
+
+ # only rank 0 will add a FileHandler
+ if rank == 0 and log_file is not None:
+ # Here, the default behaviour of the official logger is 'a'. Thus, we
+ # provide an interface to change the file mode to the default
+ # behaviour.
+ file_handler = logging.FileHandler(log_file, file_mode)
+ handlers.append(file_handler)
+
+ formatter = logging.Formatter(
+ '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ for handler in handlers:
+ handler.setFormatter(formatter)
+ handler.setLevel(log_level)
+ logger.addHandler(handler)
+
+ if rank == 0:
+ logger.setLevel(log_level)
+ else:
+ logger.setLevel(logging.ERROR)
+
+ logger_initialized[name] = True
+
+ return logger
+
+
+def print_log(msg, logger=None, level=logging.INFO):
+ """Print a log message.
+
+ Args:
+ msg (str): The message to be logged.
+ logger (logging.Logger | str | None): The logger to be used.
+ Some special loggers are:
+ - "silent": no message will be printed.
+ - other str: the logger obtained with `get_root_logger(logger)`.
+ - None: The `print()` method will be used to print log messages.
+ level (int): Logging level. Only available when `logger` is a Logger
+ object or "root".
+ """
+ if logger is None:
+ print(msg)
+ elif isinstance(logger, logging.Logger):
+ logger.log(level, msg)
+ elif logger == 'silent':
+ pass
+ elif isinstance(logger, str):
+ _logger = get_logger(logger)
+ _logger.log(level, msg)
+ else:
+ raise TypeError(
+ 'logger should be either a logging.Logger object, str, '
+ f'"silent" or None, but got {type(logger)}')
diff --git a/src/custom_mmpkg/custom_mmcv/utils/misc.py b/src/custom_mmpkg/custom_mmcv/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c58d0d7fee9fe3d4519270ad8c1e998d0d8a18c
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/utils/misc.py
@@ -0,0 +1,377 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import collections.abc
+import functools
+import itertools
+import subprocess
+import warnings
+from collections import abc
+from importlib import import_module
+from inspect import getfullargspec
+from itertools import repeat
+
+
+# From PyTorch internals
+def _ntuple(n):
+
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
+
+
+def is_str(x):
+ """Whether the input is an string instance.
+
+ Note: This method is deprecated since python 2 is no longer supported.
+ """
+ return isinstance(x, str)
+
+
+def import_modules_from_strings(imports, allow_failed_imports=False):
+ """Import modules from the given list of strings.
+
+ Args:
+ imports (list | str | None): The given module names to be imported.
+ allow_failed_imports (bool): If True, the failed imports will return
+ None. Otherwise, an ImportError is raise. Default: False.
+
+ Returns:
+ list[module] | module | None: The imported modules.
+
+ Examples:
+ >>> osp, sys = import_modules_from_strings(
+ ... ['os.path', 'sys'])
+ >>> import os.path as osp_
+ >>> import sys as sys_
+ >>> assert osp == osp_
+ >>> assert sys == sys_
+ """
+ if not imports:
+ return
+ single_import = False
+ if isinstance(imports, str):
+ single_import = True
+ imports = [imports]
+ if not isinstance(imports, list):
+ raise TypeError(
+ f'custom_imports must be a list but got type {type(imports)}')
+ imported = []
+ for imp in imports:
+ if not isinstance(imp, str):
+ raise TypeError(
+ f'{imp} is of type {type(imp)} and cannot be imported.')
+ try:
+ imported_tmp = import_module(imp)
+ except ImportError:
+ if allow_failed_imports:
+ warnings.warn(f'{imp} failed to import and is ignored.',
+ UserWarning)
+ imported_tmp = None
+ else:
+ raise ImportError
+ imported.append(imported_tmp)
+ if single_import:
+ imported = imported[0]
+ return imported
+
+
+def iter_cast(inputs, dst_type, return_type=None):
+ """Cast elements of an iterable object into some type.
+
+ Args:
+ inputs (Iterable): The input object.
+ dst_type (type): Destination type.
+ return_type (type, optional): If specified, the output object will be
+ converted to this type, otherwise an iterator.
+
+ Returns:
+ iterator or specified type: The converted object.
+ """
+ if not isinstance(inputs, abc.Iterable):
+ raise TypeError('inputs must be an iterable object')
+ if not isinstance(dst_type, type):
+ raise TypeError('"dst_type" must be a valid type')
+
+ out_iterable = map(dst_type, inputs)
+
+ if return_type is None:
+ return out_iterable
+ else:
+ return return_type(out_iterable)
+
+
+def list_cast(inputs, dst_type):
+ """Cast elements of an iterable object into a list of some type.
+
+ A partial method of :func:`iter_cast`.
+ """
+ return iter_cast(inputs, dst_type, return_type=list)
+
+
+def tuple_cast(inputs, dst_type):
+ """Cast elements of an iterable object into a tuple of some type.
+
+ A partial method of :func:`iter_cast`.
+ """
+ return iter_cast(inputs, dst_type, return_type=tuple)
+
+
+def is_seq_of(seq, expected_type, seq_type=None):
+ """Check whether it is a sequence of some type.
+
+ Args:
+ seq (Sequence): The sequence to be checked.
+ expected_type (type): Expected type of sequence items.
+ seq_type (type, optional): Expected sequence type.
+
+ Returns:
+ bool: Whether the sequence is valid.
+ """
+ if seq_type is None:
+ exp_seq_type = abc.Sequence
+ else:
+ assert isinstance(seq_type, type)
+ exp_seq_type = seq_type
+ if not isinstance(seq, exp_seq_type):
+ return False
+ for item in seq:
+ if not isinstance(item, expected_type):
+ return False
+ return True
+
+
+def is_list_of(seq, expected_type):
+ """Check whether it is a list of some type.
+
+ A partial method of :func:`is_seq_of`.
+ """
+ return is_seq_of(seq, expected_type, seq_type=list)
+
+
+def is_tuple_of(seq, expected_type):
+ """Check whether it is a tuple of some type.
+
+ A partial method of :func:`is_seq_of`.
+ """
+ return is_seq_of(seq, expected_type, seq_type=tuple)
+
+
+def slice_list(in_list, lens):
+ """Slice a list into several sub lists by a list of given length.
+
+ Args:
+ in_list (list): The list to be sliced.
+ lens(int or list): The expected length of each out list.
+
+ Returns:
+ list: A list of sliced list.
+ """
+ if isinstance(lens, int):
+ assert len(in_list) % lens == 0
+ lens = [lens] * int(len(in_list) / lens)
+ if not isinstance(lens, list):
+ raise TypeError('"indices" must be an integer or a list of integers')
+ elif sum(lens) != len(in_list):
+ raise ValueError('sum of lens and list length does not '
+ f'match: {sum(lens)} != {len(in_list)}')
+ out_list = []
+ idx = 0
+ for i in range(len(lens)):
+ out_list.append(in_list[idx:idx + lens[i]])
+ idx += lens[i]
+ return out_list
+
+
+def concat_list(in_list):
+ """Concatenate a list of list into a single list.
+
+ Args:
+ in_list (list): The list of list to be merged.
+
+ Returns:
+ list: The concatenated flat list.
+ """
+ return list(itertools.chain(*in_list))
+
+
+def check_prerequisites(
+ prerequisites,
+ checker,
+ msg_tmpl='Prerequisites "{}" are required in method "{}" but not '
+ 'found, please install them first.'): # yapf: disable
+ """A decorator factory to check if prerequisites are satisfied.
+
+ Args:
+ prerequisites (str of list[str]): Prerequisites to be checked.
+ checker (callable): The checker method that returns True if a
+ prerequisite is meet, False otherwise.
+ msg_tmpl (str): The message template with two variables.
+
+ Returns:
+ decorator: A specific decorator.
+ """
+
+ def wrap(func):
+
+ @functools.wraps(func)
+ def wrapped_func(*args, **kwargs):
+ requirements = [prerequisites] if isinstance(
+ prerequisites, str) else prerequisites
+ missing = []
+ for item in requirements:
+ if not checker(item):
+ missing.append(item)
+ if missing:
+ print(msg_tmpl.format(', '.join(missing), func.__name__))
+ raise RuntimeError('Prerequisites not meet.')
+ else:
+ return func(*args, **kwargs)
+
+ return wrapped_func
+
+ return wrap
+
+
+def _check_py_package(package):
+ try:
+ import_module(package)
+ except ImportError:
+ return False
+ else:
+ return True
+
+
+def _check_executable(cmd):
+ if subprocess.call(f'which {cmd}', shell=True) != 0:
+ return False
+ else:
+ return True
+
+
+def requires_package(prerequisites):
+ """A decorator to check if some python packages are installed.
+
+ Example:
+ >>> @requires_package('numpy')
+ >>> func(arg1, args):
+ >>> return numpy.zeros(1)
+ array([0.])
+ >>> @requires_package(['numpy', 'non_package'])
+ >>> func(arg1, args):
+ >>> return numpy.zeros(1)
+ ImportError
+ """
+ return check_prerequisites(prerequisites, checker=_check_py_package)
+
+
+def requires_executable(prerequisites):
+ """A decorator to check if some executable files are installed.
+
+ Example:
+ >>> @requires_executable('ffmpeg')
+ >>> func(arg1, args):
+ >>> print(1)
+ 1
+ """
+ return check_prerequisites(prerequisites, checker=_check_executable)
+
+
+def deprecated_api_warning(name_dict, cls_name=None):
+ """A decorator to check if some arguments are deprecate and try to replace
+ deprecate src_arg_name to dst_arg_name.
+
+ Args:
+ name_dict(dict):
+ key (str): Deprecate argument names.
+ val (str): Expected argument names.
+
+ Returns:
+ func: New function.
+ """
+
+ def api_warning_wrapper(old_func):
+
+ @functools.wraps(old_func)
+ def new_func(*args, **kwargs):
+ # get the arg spec of the decorated method
+ args_info = getfullargspec(old_func)
+ # get name of the function
+ func_name = old_func.__name__
+ if cls_name is not None:
+ func_name = f'{cls_name}.{func_name}'
+ if args:
+ arg_names = args_info.args[:len(args)]
+ for src_arg_name, dst_arg_name in name_dict.items():
+ if src_arg_name in arg_names:
+ warnings.warn(
+ f'"{src_arg_name}" is deprecated in '
+ f'`{func_name}`, please use "{dst_arg_name}" '
+ 'instead')
+ arg_names[arg_names.index(src_arg_name)] = dst_arg_name
+ if kwargs:
+ for src_arg_name, dst_arg_name in name_dict.items():
+ if src_arg_name in kwargs:
+
+ assert dst_arg_name not in kwargs, (
+ f'The expected behavior is to replace '
+ f'the deprecated key `{src_arg_name}` to '
+ f'new key `{dst_arg_name}`, but got them '
+ f'in the arguments at the same time, which '
+ f'is confusing. `{src_arg_name} will be '
+ f'deprecated in the future, please '
+ f'use `{dst_arg_name}` instead.')
+
+ warnings.warn(
+ f'"{src_arg_name}" is deprecated in '
+ f'`{func_name}`, please use "{dst_arg_name}" '
+ 'instead')
+ kwargs[dst_arg_name] = kwargs.pop(src_arg_name)
+
+ # apply converted arguments to the decorated method
+ output = old_func(*args, **kwargs)
+ return output
+
+ return new_func
+
+ return api_warning_wrapper
+
+
+def is_method_overridden(method, base_class, derived_class):
+ """Check if a method of base class is overridden in derived class.
+
+ Args:
+ method (str): the method name to check.
+ base_class (type): the class of the base class.
+ derived_class (type | Any): the class or instance of the derived class.
+ """
+ assert isinstance(base_class, type), \
+ "base_class doesn't accept instance, Please pass class instead."
+
+ if not isinstance(derived_class, type):
+ derived_class = derived_class.__class__
+
+ base_method = getattr(base_class, method)
+ derived_method = getattr(derived_class, method)
+ return derived_method != base_method
+
+
+def has_method(obj: object, method: str) -> bool:
+ """Check whether the object has a method.
+
+ Args:
+ method (str): The method name to check.
+ obj (object): The object to check.
+
+ Returns:
+ bool: True if the object has the method else False.
+ """
+ return hasattr(obj, method) and callable(getattr(obj, method))
diff --git a/src/custom_mmpkg/custom_mmcv/utils/parrots_jit.py b/src/custom_mmpkg/custom_mmcv/utils/parrots_jit.py
new file mode 100644
index 0000000000000000000000000000000000000000..61873f6dbb9b10ed972c90aa8faa321e3cb3249e
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/utils/parrots_jit.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+
+from .parrots_wrapper import TORCH_VERSION
+
+parrots_jit_option = os.getenv('PARROTS_JIT_OPTION')
+
+if TORCH_VERSION == 'parrots' and parrots_jit_option == 'ON':
+ from parrots.jit import pat as jit
+else:
+
+ def jit(func=None,
+ check_input=None,
+ full_shape=True,
+ derivate=False,
+ coderize=False,
+ optimize=False):
+
+ def wrapper(func):
+
+ def wrapper_inner(*args, **kargs):
+ return func(*args, **kargs)
+
+ return wrapper_inner
+
+ if func is None:
+ return wrapper
+ else:
+ return func
+
+
+if TORCH_VERSION == 'parrots':
+ from parrots.utils.tester import skip_no_elena
+else:
+
+ def skip_no_elena(func):
+
+ def wrapper(*args, **kargs):
+ return func(*args, **kargs)
+
+ return wrapper
diff --git a/src/custom_mmpkg/custom_mmcv/utils/parrots_wrapper.py b/src/custom_mmpkg/custom_mmcv/utils/parrots_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..93c97640d4b9ed088ca82cfe03e6efebfcfa9dbf
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/utils/parrots_wrapper.py
@@ -0,0 +1,107 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from functools import partial
+
+import torch
+
+TORCH_VERSION = torch.__version__
+
+
+def is_rocm_pytorch() -> bool:
+ is_rocm = False
+ if TORCH_VERSION != 'parrots':
+ try:
+ from torch.utils.cpp_extension import ROCM_HOME
+ is_rocm = True if ((torch.version.hip is not None) and
+ (ROCM_HOME is not None)) else False
+ except ImportError:
+ pass
+ return is_rocm
+
+
+def _get_cuda_home():
+ if TORCH_VERSION == 'parrots':
+ from parrots.utils.build_extension import CUDA_HOME
+ else:
+ if is_rocm_pytorch():
+ from torch.utils.cpp_extension import ROCM_HOME
+ CUDA_HOME = ROCM_HOME
+ else:
+ from torch.utils.cpp_extension import CUDA_HOME
+ return CUDA_HOME
+
+
+def get_build_config():
+ if TORCH_VERSION == 'parrots':
+ from parrots.config import get_build_info
+ return get_build_info()
+ else:
+ return torch.__config__.show()
+
+
+def _get_conv():
+ if TORCH_VERSION == 'parrots':
+ from parrots.nn.modules.conv import _ConvNd, _ConvTransposeMixin
+ else:
+ from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin
+ return _ConvNd, _ConvTransposeMixin
+
+
+def _get_dataloader():
+ if TORCH_VERSION == 'parrots':
+ from torch.utils.data import DataLoader, PoolDataLoader
+ else:
+ from torch.utils.data import DataLoader
+ PoolDataLoader = DataLoader
+ return DataLoader, PoolDataLoader
+
+
+def _get_extension():
+ if TORCH_VERSION == 'parrots':
+ from parrots.utils.build_extension import BuildExtension, Extension
+ CppExtension = partial(Extension, cuda=False)
+ CUDAExtension = partial(Extension, cuda=True)
+ else:
+ from torch.utils.cpp_extension import (BuildExtension, CppExtension,
+ CUDAExtension)
+ return BuildExtension, CppExtension, CUDAExtension
+
+
+def _get_pool():
+ if TORCH_VERSION == 'parrots':
+ from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd,
+ _AdaptiveMaxPoolNd, _AvgPoolNd,
+ _MaxPoolNd)
+ else:
+ from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd,
+ _AdaptiveMaxPoolNd, _AvgPoolNd,
+ _MaxPoolNd)
+ return _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd
+
+
+def _get_norm():
+ if TORCH_VERSION == 'parrots':
+ from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm
+ SyncBatchNorm_ = torch.nn.SyncBatchNorm2d
+ else:
+ from torch.nn.modules.instancenorm import _InstanceNorm
+ from torch.nn.modules.batchnorm import _BatchNorm
+ SyncBatchNorm_ = torch.nn.SyncBatchNorm
+ return _BatchNorm, _InstanceNorm, SyncBatchNorm_
+
+
+_ConvNd, _ConvTransposeMixin = _get_conv()
+DataLoader, PoolDataLoader = _get_dataloader()
+BuildExtension, CppExtension, CUDAExtension = _get_extension()
+_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
+_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
+
+
+class SyncBatchNorm(SyncBatchNorm_):
+
+ def _check_input_dim(self, input):
+ if TORCH_VERSION == 'parrots':
+ if input.dim() < 2:
+ raise ValueError(
+ f'expected at least 2D input (got {input.dim()}D input)')
+ else:
+ super()._check_input_dim(input)
diff --git a/src/custom_mmpkg/custom_mmcv/utils/path.py b/src/custom_mmpkg/custom_mmcv/utils/path.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dab4b3041413b1432b0f434b8b14783097d33c6
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/utils/path.py
@@ -0,0 +1,101 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import os.path as osp
+from pathlib import Path
+
+from .misc import is_str
+
+
+def is_filepath(x):
+ return is_str(x) or isinstance(x, Path)
+
+
+def fopen(filepath, *args, **kwargs):
+ if is_str(filepath):
+ return open(filepath, *args, **kwargs)
+ elif isinstance(filepath, Path):
+ return filepath.open(*args, **kwargs)
+ raise ValueError('`filepath` should be a string or a Path')
+
+
+def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
+ if not osp.isfile(filename):
+ raise FileNotFoundError(msg_tmpl.format(filename))
+
+
+def mkdir_or_exist(dir_name, mode=0o777):
+ if dir_name == '':
+ return
+ dir_name = osp.expanduser(dir_name)
+ os.makedirs(dir_name, mode=mode, exist_ok=True)
+
+
+def symlink(src, dst, overwrite=True, **kwargs):
+ if os.path.lexists(dst) and overwrite:
+ os.remove(dst)
+ os.symlink(src, dst, **kwargs)
+
+
+def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
+ """Scan a directory to find the interested files.
+
+ Args:
+ dir_path (str | obj:`Path`): Path of the directory.
+ suffix (str | tuple(str), optional): File suffix that we are
+ interested in. Default: None.
+ recursive (bool, optional): If set to True, recursively scan the
+ directory. Default: False.
+ case_sensitive (bool, optional) : If set to False, ignore the case of
+ suffix. Default: True.
+
+ Returns:
+ A generator for all the interested files with relative paths.
+ """
+ if isinstance(dir_path, (str, Path)):
+ dir_path = str(dir_path)
+ else:
+ raise TypeError('"dir_path" must be a string or Path object')
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('"suffix" must be a string or tuple of strings')
+
+ if suffix is not None and not case_sensitive:
+ suffix = suffix.lower() if isinstance(suffix, str) else tuple(
+ item.lower() for item in suffix)
+
+ root = dir_path
+
+ def _scandir(dir_path, suffix, recursive, case_sensitive):
+ for entry in os.scandir(dir_path):
+ if not entry.name.startswith('.') and entry.is_file():
+ rel_path = osp.relpath(entry.path, root)
+ _rel_path = rel_path if case_sensitive else rel_path.lower()
+ if suffix is None or _rel_path.endswith(suffix):
+ yield rel_path
+ elif recursive and os.path.isdir(entry.path):
+ # scan recursively if entry.path is a directory
+ yield from _scandir(entry.path, suffix, recursive,
+ case_sensitive)
+
+ return _scandir(dir_path, suffix, recursive, case_sensitive)
+
+
+def find_vcs_root(path, markers=('.git', )):
+ """Finds the root directory (including itself) of specified markers.
+
+ Args:
+ path (str): Path of directory or file.
+ markers (list[str], optional): List of file or directory names.
+
+ Returns:
+ The directory contained one of the markers or None if not found.
+ """
+ if osp.isfile(path):
+ path = osp.dirname(path)
+
+ prev, cur = None, osp.abspath(osp.expanduser(path))
+ while cur != prev:
+ if any(osp.exists(osp.join(cur, marker)) for marker in markers):
+ return cur
+ prev, cur = cur, osp.split(cur)[0]
+ return None
diff --git a/src/custom_mmpkg/custom_mmcv/utils/progressbar.py b/src/custom_mmpkg/custom_mmcv/utils/progressbar.py
new file mode 100644
index 0000000000000000000000000000000000000000..0062f670dd94fa9da559ab26ef85517dcf5211c7
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/utils/progressbar.py
@@ -0,0 +1,208 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import sys
+from collections.abc import Iterable
+from multiprocessing import Pool
+from shutil import get_terminal_size
+
+from .timer import Timer
+
+
+class ProgressBar:
+ """A progress bar which can print the progress."""
+
+ def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
+ self.task_num = task_num
+ self.bar_width = bar_width
+ self.completed = 0
+ self.file = file
+ if start:
+ self.start()
+
+ @property
+ def terminal_width(self):
+ width, _ = get_terminal_size()
+ return width
+
+ def start(self):
+ if self.task_num > 0:
+ self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, '
+ 'elapsed: 0s, ETA:')
+ else:
+ self.file.write('completed: 0, elapsed: 0s')
+ self.file.flush()
+ self.timer = Timer()
+
+ def update(self, num_tasks=1):
+ assert num_tasks > 0
+ self.completed += num_tasks
+ elapsed = self.timer.since_start()
+ if elapsed > 0:
+ fps = self.completed / elapsed
+ else:
+ fps = float('inf')
+ if self.task_num > 0:
+ percentage = self.completed / float(self.task_num)
+ eta = int(elapsed * (1 - percentage) / percentage + 0.5)
+ msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \
+ f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \
+ f'ETA: {eta:5}s'
+
+ bar_width = min(self.bar_width,
+ int(self.terminal_width - len(msg)) + 2,
+ int(self.terminal_width * 0.6))
+ bar_width = max(2, bar_width)
+ mark_width = int(bar_width * percentage)
+ bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width)
+ self.file.write(msg.format(bar_chars))
+ else:
+ self.file.write(
+ f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,'
+ f' {fps:.1f} tasks/s')
+ self.file.flush()
+
+
+def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs):
+ """Track the progress of tasks execution with a progress bar.
+
+ Tasks are done with a simple for-loop.
+
+ Args:
+ func (callable): The function to be applied to each task.
+ tasks (list or tuple[Iterable, int]): A list of tasks or
+ (tasks, total num).
+ bar_width (int): Width of progress bar.
+
+ Returns:
+ list: The task results.
+ """
+ if isinstance(tasks, tuple):
+ assert len(tasks) == 2
+ assert isinstance(tasks[0], Iterable)
+ assert isinstance(tasks[1], int)
+ task_num = tasks[1]
+ tasks = tasks[0]
+ elif isinstance(tasks, Iterable):
+ task_num = len(tasks)
+ else:
+ raise TypeError(
+ '"tasks" must be an iterable object or a (iterator, int) tuple')
+ prog_bar = ProgressBar(task_num, bar_width, file=file)
+ results = []
+ for task in tasks:
+ results.append(func(task, **kwargs))
+ prog_bar.update()
+ prog_bar.file.write('\n')
+ return results
+
+
+def init_pool(process_num, initializer=None, initargs=None):
+ if initializer is None:
+ return Pool(process_num)
+ elif initargs is None:
+ return Pool(process_num, initializer)
+ else:
+ if not isinstance(initargs, tuple):
+ raise TypeError('"initargs" must be a tuple')
+ return Pool(process_num, initializer, initargs)
+
+
+def track_parallel_progress(func,
+ tasks,
+ nproc,
+ initializer=None,
+ initargs=None,
+ bar_width=50,
+ chunksize=1,
+ skip_first=False,
+ keep_order=True,
+ file=sys.stdout):
+ """Track the progress of parallel task execution with a progress bar.
+
+ The built-in :mod:`multiprocessing` module is used for process pools and
+ tasks are done with :func:`Pool.map` or :func:`Pool.imap_unordered`.
+
+ Args:
+ func (callable): The function to be applied to each task.
+ tasks (list or tuple[Iterable, int]): A list of tasks or
+ (tasks, total num).
+ nproc (int): Process (worker) number.
+ initializer (None or callable): Refer to :class:`multiprocessing.Pool`
+ for details.
+ initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for
+ details.
+ chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
+ bar_width (int): Width of progress bar.
+ skip_first (bool): Whether to skip the first sample for each worker
+ when estimating fps, since the initialization step may takes
+ longer.
+ keep_order (bool): If True, :func:`Pool.imap` is used, otherwise
+ :func:`Pool.imap_unordered` is used.
+
+ Returns:
+ list: The task results.
+ """
+ if isinstance(tasks, tuple):
+ assert len(tasks) == 2
+ assert isinstance(tasks[0], Iterable)
+ assert isinstance(tasks[1], int)
+ task_num = tasks[1]
+ tasks = tasks[0]
+ elif isinstance(tasks, Iterable):
+ task_num = len(tasks)
+ else:
+ raise TypeError(
+ '"tasks" must be an iterable object or a (iterator, int) tuple')
+ pool = init_pool(nproc, initializer, initargs)
+ start = not skip_first
+ task_num -= nproc * chunksize * int(skip_first)
+ prog_bar = ProgressBar(task_num, bar_width, start, file=file)
+ results = []
+ if keep_order:
+ gen = pool.imap(func, tasks, chunksize)
+ else:
+ gen = pool.imap_unordered(func, tasks, chunksize)
+ for result in gen:
+ results.append(result)
+ if skip_first:
+ if len(results) < nproc * chunksize:
+ continue
+ elif len(results) == nproc * chunksize:
+ prog_bar.start()
+ continue
+ prog_bar.update()
+ prog_bar.file.write('\n')
+ pool.close()
+ pool.join()
+ return results
+
+
+def track_iter_progress(tasks, bar_width=50, file=sys.stdout):
+ """Track the progress of tasks iteration or enumeration with a progress
+ bar.
+
+ Tasks are yielded with a simple for-loop.
+
+ Args:
+ tasks (list or tuple[Iterable, int]): A list of tasks or
+ (tasks, total num).
+ bar_width (int): Width of progress bar.
+
+ Yields:
+ list: The task results.
+ """
+ if isinstance(tasks, tuple):
+ assert len(tasks) == 2
+ assert isinstance(tasks[0], Iterable)
+ assert isinstance(tasks[1], int)
+ task_num = tasks[1]
+ tasks = tasks[0]
+ elif isinstance(tasks, Iterable):
+ task_num = len(tasks)
+ else:
+ raise TypeError(
+ '"tasks" must be an iterable object or a (iterator, int) tuple')
+ prog_bar = ProgressBar(task_num, bar_width, file=file)
+ for task in tasks:
+ yield task
+ prog_bar.update()
+ prog_bar.file.write('\n')
diff --git a/src/custom_mmpkg/custom_mmcv/utils/registry.py b/src/custom_mmpkg/custom_mmcv/utils/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa9df39bc9f3d8d568361e7250ab35468f2b74e0
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/utils/registry.py
@@ -0,0 +1,315 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+import warnings
+from functools import partial
+
+from .misc import is_seq_of
+
+
+def build_from_cfg(cfg, registry, default_args=None):
+ """Build a module from config dict.
+
+ Args:
+ cfg (dict): Config dict. It should at least contain the key "type".
+ registry (:obj:`Registry`): The registry to search the type from.
+ default_args (dict, optional): Default initialization arguments.
+
+ Returns:
+ object: The constructed object.
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
+ if 'type' not in cfg:
+ if default_args is None or 'type' not in default_args:
+ raise KeyError(
+ '`cfg` or `default_args` must contain the key "type", '
+ f'but got {cfg}\n{default_args}')
+ if not isinstance(registry, Registry):
+ raise TypeError('registry must be an mmcv.Registry object, '
+ f'but got {type(registry)}')
+ if not (isinstance(default_args, dict) or default_args is None):
+ raise TypeError('default_args must be a dict or None, '
+ f'but got {type(default_args)}')
+
+ args = cfg.copy()
+
+ if default_args is not None:
+ for name, value in default_args.items():
+ args.setdefault(name, value)
+
+ obj_type = args.pop('type')
+ if isinstance(obj_type, str):
+ obj_cls = registry.get(obj_type)
+ if obj_cls is None:
+ raise KeyError(
+ f'{obj_type} is not in the {registry.name} registry')
+ elif inspect.isclass(obj_type):
+ obj_cls = obj_type
+ else:
+ raise TypeError(
+ f'type must be a str or valid type, but got {type(obj_type)}')
+ try:
+ return obj_cls(**args)
+ except Exception as e:
+ # Normal TypeError does not print class name.
+ raise type(e)(f'{obj_cls.__name__}: {e}')
+
+
+class Registry:
+ """A registry to map strings to classes.
+
+ Registered object could be built from registry.
+ Example:
+ >>> MODELS = Registry('models')
+ >>> @MODELS.register_module()
+ >>> class ResNet:
+ >>> pass
+ >>> resnet = MODELS.build(dict(type='ResNet'))
+
+ Please refer to
+ https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
+ advanced usage.
+
+ Args:
+ name (str): Registry name.
+ build_func(func, optional): Build function to construct instance from
+ Registry, func:`build_from_cfg` is used if neither ``parent`` or
+ ``build_func`` is specified. If ``parent`` is specified and
+ ``build_func`` is not given, ``build_func`` will be inherited
+ from ``parent``. Default: None.
+ parent (Registry, optional): Parent registry. The class registered in
+ children registry could be built from parent. Default: None.
+ scope (str, optional): The scope of registry. It is the key to search
+ for children registry. If not specified, scope will be the name of
+ the package where class is defined, e.g. mmdet, mmcls, mmseg.
+ Default: None.
+ """
+
+ def __init__(self, name, build_func=None, parent=None, scope=None):
+ self._name = name
+ self._module_dict = dict()
+ self._children = dict()
+ self._scope = self.infer_scope() if scope is None else scope
+
+ # self.build_func will be set with the following priority:
+ # 1. build_func
+ # 2. parent.build_func
+ # 3. build_from_cfg
+ if build_func is None:
+ if parent is not None:
+ self.build_func = parent.build_func
+ else:
+ self.build_func = build_from_cfg
+ else:
+ self.build_func = build_func
+ if parent is not None:
+ assert isinstance(parent, Registry)
+ parent._add_children(self)
+ self.parent = parent
+ else:
+ self.parent = None
+
+ def __len__(self):
+ return len(self._module_dict)
+
+ def __contains__(self, key):
+ return self.get(key) is not None
+
+ def __repr__(self):
+ format_str = self.__class__.__name__ + \
+ f'(name={self._name}, ' \
+ f'items={self._module_dict})'
+ return format_str
+
+ @staticmethod
+ def infer_scope():
+ """Infer the scope of registry.
+
+ The name of the package where registry is defined will be returned.
+
+ Example:
+ # in mmdet/models/backbone/resnet.py
+ >>> MODELS = Registry('models')
+ >>> @MODELS.register_module()
+ >>> class ResNet:
+ >>> pass
+ The scope of ``ResNet`` will be ``mmdet``.
+
+
+ Returns:
+ scope (str): The inferred scope name.
+ """
+ # inspect.stack() trace where this function is called, the index-2
+ # indicates the frame where `infer_scope()` is called
+ filename = inspect.getmodule(inspect.stack()[2][0]).__name__
+ split_filename = filename.split('.')
+ return split_filename[0]
+
+ @staticmethod
+ def split_scope_key(key):
+ """Split scope and key.
+
+ The first scope will be split from key.
+
+ Examples:
+ >>> Registry.split_scope_key('mmdet.ResNet')
+ 'mmdet', 'ResNet'
+ >>> Registry.split_scope_key('ResNet')
+ None, 'ResNet'
+
+ Return:
+ scope (str, None): The first scope.
+ key (str): The remaining key.
+ """
+ split_index = key.find('.')
+ if split_index != -1:
+ return key[:split_index], key[split_index + 1:]
+ else:
+ return None, key
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def scope(self):
+ return self._scope
+
+ @property
+ def module_dict(self):
+ return self._module_dict
+
+ @property
+ def children(self):
+ return self._children
+
+ def get(self, key):
+ """Get the registry record.
+
+ Args:
+ key (str): The class name in string format.
+
+ Returns:
+ class: The corresponding class.
+ """
+ scope, real_key = self.split_scope_key(key)
+ if scope is None or scope == self._scope:
+ # get from self
+ if real_key in self._module_dict:
+ return self._module_dict[real_key]
+ else:
+ # get from self._children
+ if scope in self._children:
+ return self._children[scope].get(real_key)
+ else:
+ # goto root
+ parent = self.parent
+ while parent.parent is not None:
+ parent = parent.parent
+ return parent.get(key)
+
+ def build(self, *args, **kwargs):
+ return self.build_func(*args, **kwargs, registry=self)
+
+ def _add_children(self, registry):
+ """Add children for a registry.
+
+ The ``registry`` will be added as children based on its scope.
+ The parent registry could build objects from children registry.
+
+ Example:
+ >>> models = Registry('models')
+ >>> mmdet_models = Registry('models', parent=models)
+ >>> @mmdet_models.register_module()
+ >>> class ResNet:
+ >>> pass
+ >>> resnet = models.build(dict(type='mmdet.ResNet'))
+ """
+
+ assert isinstance(registry, Registry)
+ assert registry.scope is not None
+ assert registry.scope not in self.children, \
+ f'scope {registry.scope} exists in {self.name} registry'
+ self.children[registry.scope] = registry
+
+ def _register_module(self, module_class, module_name=None, force=False):
+ if not inspect.isclass(module_class):
+ raise TypeError('module must be a class, '
+ f'but got {type(module_class)}')
+
+ if module_name is None:
+ module_name = module_class.__name__
+ if isinstance(module_name, str):
+ module_name = [module_name]
+ for name in module_name:
+ if not force and name in self._module_dict:
+ raise KeyError(f'{name} is already registered '
+ f'in {self.name}')
+ self._module_dict[name] = module_class
+
+ def deprecated_register_module(self, cls=None, force=False):
+ warnings.warn(
+ 'The old API of register_module(module, force=False) '
+ 'is deprecated and will be removed, please use the new API '
+ 'register_module(name=None, force=False, module=None) instead.')
+ if cls is None:
+ return partial(self.deprecated_register_module, force=force)
+ self._register_module(cls, force=force)
+ return cls
+
+ def register_module(self, name=None, force=False, module=None):
+ """Register a module.
+
+ A record will be added to `self._module_dict`, whose key is the class
+ name or the specified name, and value is the class itself.
+ It can be used as a decorator or a normal function.
+
+ Example:
+ >>> backbones = Registry('backbone')
+ >>> @backbones.register_module()
+ >>> class ResNet:
+ >>> pass
+
+ >>> backbones = Registry('backbone')
+ >>> @backbones.register_module(name='mnet')
+ >>> class MobileNet:
+ >>> pass
+
+ >>> backbones = Registry('backbone')
+ >>> class ResNet:
+ >>> pass
+ >>> backbones.register_module(ResNet)
+
+ Args:
+ name (str | None): The module name to be registered. If not
+ specified, the class name will be used.
+ force (bool, optional): Whether to override an existing class with
+ the same name. Default: False.
+ module (type): Module class to be registered.
+ """
+ if not isinstance(force, bool):
+ raise TypeError(f'force must be a boolean, but got {type(force)}')
+ # NOTE: This is a walkaround to be compatible with the old api,
+ # while it may introduce unexpected bugs.
+ if isinstance(name, type):
+ return self.deprecated_register_module(name, force=force)
+
+ # raise the error ahead of time
+ if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
+ raise TypeError(
+ 'name must be either of None, an instance of str or a sequence'
+ f' of str, but got {type(name)}')
+
+ # use it as a normal method: x.register_module(module=SomeClass)
+ if module is not None:
+ self._register_module(
+ module_class=module, module_name=name, force=force)
+ return module
+
+ # use it as a decorator: @x.register_module()
+ def _register(cls):
+ self._register_module(
+ module_class=cls, module_name=name, force=force)
+ return cls
+
+ return _register
diff --git a/src/custom_mmpkg/custom_mmcv/utils/testing.py b/src/custom_mmpkg/custom_mmcv/utils/testing.py
new file mode 100644
index 0000000000000000000000000000000000000000..a27f936da8ec14bac18562ede0a79d476d82f797
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/utils/testing.py
@@ -0,0 +1,140 @@
+# Copyright (c) Open-MMLab.
+import sys
+from collections.abc import Iterable
+from runpy import run_path
+from shlex import split
+from typing import Any, Dict, List
+from unittest.mock import patch
+
+
+def check_python_script(cmd):
+ """Run the python cmd script with `__main__`. The difference between
+ `os.system` is that, this function exectues code in the current process, so
+ that it can be tracked by coverage tools. Currently it supports two forms:
+
+ - ./tests/data/scripts/hello.py zz
+ - python tests/data/scripts/hello.py zz
+ """
+ args = split(cmd)
+ if args[0] == 'python':
+ args = args[1:]
+ with patch.object(sys, 'argv', args):
+ run_path(args[0], run_name='__main__')
+
+
+def _any(judge_result):
+ """Since built-in ``any`` works only when the element of iterable is not
+ iterable, implement the function."""
+ if not isinstance(judge_result, Iterable):
+ return judge_result
+
+ try:
+ for element in judge_result:
+ if _any(element):
+ return True
+ except TypeError:
+ # Maybe encounter the case: torch.tensor(True) | torch.tensor(False)
+ if judge_result:
+ return True
+ return False
+
+
+def assert_dict_contains_subset(dict_obj: Dict[Any, Any],
+ expected_subset: Dict[Any, Any]) -> bool:
+ """Check if the dict_obj contains the expected_subset.
+
+ Args:
+ dict_obj (Dict[Any, Any]): Dict object to be checked.
+ expected_subset (Dict[Any, Any]): Subset expected to be contained in
+ dict_obj.
+
+ Returns:
+ bool: Whether the dict_obj contains the expected_subset.
+ """
+
+ for key, value in expected_subset.items():
+ if key not in dict_obj.keys() or _any(dict_obj[key] != value):
+ return False
+ return True
+
+
+def assert_attrs_equal(obj: Any, expected_attrs: Dict[str, Any]) -> bool:
+ """Check if attribute of class object is correct.
+
+ Args:
+ obj (object): Class object to be checked.
+ expected_attrs (Dict[str, Any]): Dict of the expected attrs.
+
+ Returns:
+ bool: Whether the attribute of class object is correct.
+ """
+ for attr, value in expected_attrs.items():
+ if not hasattr(obj, attr) or _any(getattr(obj, attr) != value):
+ return False
+ return True
+
+
+def assert_dict_has_keys(obj: Dict[str, Any],
+ expected_keys: List[str]) -> bool:
+ """Check if the obj has all the expected_keys.
+
+ Args:
+ obj (Dict[str, Any]): Object to be checked.
+ expected_keys (List[str]): Keys expected to contained in the keys of
+ the obj.
+
+ Returns:
+ bool: Whether the obj has the expected keys.
+ """
+ return set(expected_keys).issubset(set(obj.keys()))
+
+
+def assert_keys_equal(result_keys: List[str], target_keys: List[str]) -> bool:
+ """Check if target_keys is equal to result_keys.
+
+ Args:
+ result_keys (List[str]): Result keys to be checked.
+ target_keys (List[str]): Target keys to be checked.
+
+ Returns:
+ bool: Whether target_keys is equal to result_keys.
+ """
+ return set(result_keys) == set(target_keys)
+
+
+def assert_is_norm_layer(module) -> bool:
+ """Check if the module is a norm layer.
+
+ Args:
+ module (nn.Module): The module to be checked.
+
+ Returns:
+ bool: Whether the module is a norm layer.
+ """
+ from .parrots_wrapper import _BatchNorm, _InstanceNorm
+ from torch.nn import GroupNorm, LayerNorm
+ norm_layer_candidates = (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)
+ return isinstance(module, norm_layer_candidates)
+
+
+def assert_params_all_zeros(module) -> bool:
+ """Check if the parameters of the module is all zeros.
+
+ Args:
+ module (nn.Module): The module to be checked.
+
+ Returns:
+ bool: Whether the parameters of the module is all zeros.
+ """
+ weight_data = module.weight.data
+ is_weight_zero = weight_data.allclose(
+ weight_data.new_zeros(weight_data.size()))
+
+ if hasattr(module, 'bias') and module.bias is not None:
+ bias_data = module.bias.data
+ is_bias_zero = bias_data.allclose(
+ bias_data.new_zeros(bias_data.size()))
+ else:
+ is_bias_zero = True
+
+ return is_weight_zero and is_bias_zero
diff --git a/src/custom_mmpkg/custom_mmcv/utils/timer.py b/src/custom_mmpkg/custom_mmcv/utils/timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5907e0edfdee7ab002e41d151e4c4386e1d9f294
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/utils/timer.py
@@ -0,0 +1,118 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from time import time
+
+
+class TimerError(Exception):
+
+ def __init__(self, message):
+ self.message = message
+ super(TimerError, self).__init__(message)
+
+
+class Timer:
+ """A flexible Timer class.
+
+ :Example:
+
+ >>> import time
+ >>> import custom_mmpkg.custom_mmcv as mmcv
+ >>> with mmcv.Timer():
+ >>> # simulate a code block that will run for 1s
+ >>> time.sleep(1)
+ 1.000
+ >>> with mmcv.Timer(print_tmpl='it takes {:.1f} seconds'):
+ >>> # simulate a code block that will run for 1s
+ >>> time.sleep(1)
+ it takes 1.0 seconds
+ >>> timer = mmcv.Timer()
+ >>> time.sleep(0.5)
+ >>> print(timer.since_start())
+ 0.500
+ >>> time.sleep(0.5)
+ >>> print(timer.since_last_check())
+ 0.500
+ >>> print(timer.since_start())
+ 1.000
+ """
+
+ def __init__(self, start=True, print_tmpl=None):
+ self._is_running = False
+ self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}'
+ if start:
+ self.start()
+
+ @property
+ def is_running(self):
+ """bool: indicate whether the timer is running"""
+ return self._is_running
+
+ def __enter__(self):
+ self.start()
+ return self
+
+ def __exit__(self, type, value, traceback):
+ print(self.print_tmpl.format(self.since_last_check()))
+ self._is_running = False
+
+ def start(self):
+ """Start the timer."""
+ if not self._is_running:
+ self._t_start = time()
+ self._is_running = True
+ self._t_last = time()
+
+ def since_start(self):
+ """Total time since the timer is started.
+
+ Returns (float): Time in seconds.
+ """
+ if not self._is_running:
+ raise TimerError('timer is not running')
+ self._t_last = time()
+ return self._t_last - self._t_start
+
+ def since_last_check(self):
+ """Time since the last checking.
+
+ Either :func:`since_start` or :func:`since_last_check` is a checking
+ operation.
+
+ Returns (float): Time in seconds.
+ """
+ if not self._is_running:
+ raise TimerError('timer is not running')
+ dur = time() - self._t_last
+ self._t_last = time()
+ return dur
+
+
+_g_timers = {} # global timers
+
+
+def check_time(timer_id):
+ """Add check points in a single line.
+
+ This method is suitable for running a task on a list of items. A timer will
+ be registered when the method is called for the first time.
+
+ :Example:
+
+ >>> import time
+ >>> import custom_mmpkg.custom_mmcv as mmcv
+ >>> for i in range(1, 6):
+ >>> # simulate a code block
+ >>> time.sleep(i)
+ >>> mmcv.check_time('task1')
+ 2.000
+ 3.000
+ 4.000
+ 5.000
+
+ Args:
+ timer_id (str): Timer identifier.
+ """
+ if timer_id not in _g_timers:
+ _g_timers[timer_id] = Timer()
+ return 0
+ else:
+ return _g_timers[timer_id].since_last_check()
diff --git a/src/custom_mmpkg/custom_mmcv/utils/trace.py b/src/custom_mmpkg/custom_mmcv/utils/trace.py
new file mode 100644
index 0000000000000000000000000000000000000000..3907185bf82775e8ed4c2bf4cd4667c5c623d188
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/utils/trace.py
@@ -0,0 +1,23 @@
+import warnings
+
+import torch
+
+from custom_mmpkg.custom_mmcv.utils import digit_version
+
+
+def is_jit_tracing() -> bool:
+ if (torch.__version__ != 'parrots'
+ and digit_version(torch.__version__) >= digit_version('1.6.0')):
+ on_trace = torch.jit.is_tracing()
+ # In PyTorch 1.6, torch.jit.is_tracing has a bug.
+ # Refers to https://github.com/pytorch/pytorch/issues/42448
+ if isinstance(on_trace, bool):
+ return on_trace
+ else:
+ return torch._C._is_tracing()
+ else:
+ warnings.warn(
+ 'torch.jit.is_tracing is only supported after v1.6.0. '
+ 'Therefore is_tracing returns False automatically. Please '
+ 'set on_trace manually if you are using trace.', UserWarning)
+ return False
diff --git a/src/custom_mmpkg/custom_mmcv/utils/version_utils.py b/src/custom_mmpkg/custom_mmcv/utils/version_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..963c45a2e8a86a88413ab6c18c22481fb9831985
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/utils/version_utils.py
@@ -0,0 +1,90 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import subprocess
+import warnings
+
+from packaging.version import parse
+
+
+def digit_version(version_str: str, length: int = 4):
+ """Convert a version string into a tuple of integers.
+
+ This method is usually used for comparing two versions. For pre-release
+ versions: alpha < beta < rc.
+
+ Args:
+ version_str (str): The version string.
+ length (int): The maximum number of version levels. Default: 4.
+
+ Returns:
+ tuple[int]: The version info in digits (integers).
+ """
+ assert 'parrots' not in version_str
+ version = parse(version_str)
+ assert version.release, f'failed to parse version {version_str}'
+ release = list(version.release)
+ release = release[:length]
+ if len(release) < length:
+ release = release + [0] * (length - len(release))
+ if version.is_prerelease:
+ mapping = {'a': -3, 'b': -2, 'rc': -1}
+ val = -4
+ # version.pre can be None
+ if version.pre:
+ if version.pre[0] not in mapping:
+ warnings.warn(f'unknown prerelease version {version.pre[0]}, '
+ 'version checking may go wrong')
+ else:
+ val = mapping[version.pre[0]]
+ release.extend([val, version.pre[-1]])
+ else:
+ release.extend([val, 0])
+
+ elif version.is_postrelease:
+ release.extend([1, version.post])
+ else:
+ release.extend([0, 0])
+ return tuple(release)
+
+
+def _minimal_ext_cmd(cmd):
+ # construct minimal environment
+ env = {}
+ for k in ['SYSTEMROOT', 'PATH', 'HOME']:
+ v = os.environ.get(k)
+ if v is not None:
+ env[k] = v
+ # LANGUAGE is used on win32
+ env['LANGUAGE'] = 'C'
+ env['LANG'] = 'C'
+ env['LC_ALL'] = 'C'
+ out = subprocess.Popen(
+ cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
+ return out
+
+
+def get_git_hash(fallback='unknown', digits=None):
+ """Get the git hash of the current repo.
+
+ Args:
+ fallback (str, optional): The fallback string when git hash is
+ unavailable. Defaults to 'unknown'.
+ digits (int, optional): kept digits of the hash. Defaults to None,
+ meaning all digits are kept.
+
+ Returns:
+ str: Git commit hash.
+ """
+
+ if digits is not None and not isinstance(digits, int):
+ raise TypeError('digits must be None or an integer')
+
+ try:
+ out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
+ sha = out.strip().decode('ascii')
+ if digits is not None:
+ sha = sha[:digits]
+ except OSError:
+ sha = fallback
+
+ return sha
diff --git a/src/custom_mmpkg/custom_mmcv/version.py b/src/custom_mmpkg/custom_mmcv/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cce4e50bd692d4002e3cac3c545a3fb2efe95d0
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/version.py
@@ -0,0 +1,35 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+__version__ = '1.3.17'
+
+
+def parse_version_info(version_str: str, length: int = 4) -> tuple:
+ """Parse a version string into a tuple.
+
+ Args:
+ version_str (str): The version string.
+ length (int): The maximum number of version levels. Default: 4.
+
+ Returns:
+ tuple[int | str]: The version info, e.g., "1.3.0" is parsed into
+ (1, 3, 0, 0, 0, 0), and "2.0.0rc1" is parsed into
+ (2, 0, 0, 0, 'rc', 1) (when length is set to 4).
+ """
+ from packaging.version import parse
+ version = parse(version_str)
+ assert version.release, f'failed to parse version {version_str}'
+ release = list(version.release)
+ release = release[:length]
+ if len(release) < length:
+ release = release + [0] * (length - len(release))
+ if version.is_prerelease:
+ release.extend(list(version.pre))
+ elif version.is_postrelease:
+ release.extend(list(version.post))
+ else:
+ release.extend([0, 0])
+ return tuple(release)
+
+
+version_info = tuple(int(x) for x in __version__.split('.')[:3])
+
+__all__ = ['__version__', 'version_info', 'parse_version_info']
diff --git a/src/custom_mmpkg/custom_mmcv/video/__init__.py b/src/custom_mmpkg/custom_mmcv/video/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73199b01dec52820dc6ca0139903536344d5a1eb
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/video/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .io import Cache, VideoReader, frames2video
+from .optflow import (dequantize_flow, flow_from_bytes, flow_warp, flowread,
+ flowwrite, quantize_flow, sparse_flow_from_bytes)
+from .processing import concat_video, convert_video, cut_video, resize_video
+
+__all__ = [
+ 'Cache', 'VideoReader', 'frames2video', 'convert_video', 'resize_video',
+ 'cut_video', 'concat_video', 'flowread', 'flowwrite', 'quantize_flow',
+ 'dequantize_flow', 'flow_warp', 'flow_from_bytes', 'sparse_flow_from_bytes'
+]
diff --git a/src/custom_mmpkg/custom_mmcv/video/io.py b/src/custom_mmpkg/custom_mmcv/video/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9c20cee37aec3e36413300b88fbdb0156bfa8a4
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/video/io.py
@@ -0,0 +1,318 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from collections import OrderedDict
+
+import cv2
+from cv2 import (CAP_PROP_FOURCC, CAP_PROP_FPS, CAP_PROP_FRAME_COUNT,
+ CAP_PROP_FRAME_HEIGHT, CAP_PROP_FRAME_WIDTH,
+ CAP_PROP_POS_FRAMES, VideoWriter_fourcc)
+
+from custom_mmpkg.custom_mmcv.utils import (check_file_exist, mkdir_or_exist, scandir,
+ track_progress)
+
+
+class Cache:
+
+ def __init__(self, capacity):
+ self._cache = OrderedDict()
+ self._capacity = int(capacity)
+ if capacity <= 0:
+ raise ValueError('capacity must be a positive integer')
+
+ @property
+ def capacity(self):
+ return self._capacity
+
+ @property
+ def size(self):
+ return len(self._cache)
+
+ def put(self, key, val):
+ if key in self._cache:
+ return
+ if len(self._cache) >= self.capacity:
+ self._cache.popitem(last=False)
+ self._cache[key] = val
+
+ def get(self, key, default=None):
+ val = self._cache[key] if key in self._cache else default
+ return val
+
+
+class VideoReader:
+ """Video class with similar usage to a list object.
+
+ This video warpper class provides convenient apis to access frames.
+ There exists an issue of OpenCV's VideoCapture class that jumping to a
+ certain frame may be inaccurate. It is fixed in this class by checking
+ the position after jumping each time.
+ Cache is used when decoding videos. So if the same frame is visited for
+ the second time, there is no need to decode again if it is stored in the
+ cache.
+
+ :Example:
+
+ >>> import custom_mmpkg.custom_mmcv as mmcv
+ >>> v = mmcv.VideoReader('sample.mp4')
+ >>> len(v) # get the total frame number with `len()`
+ 120
+ >>> for img in v: # v is iterable
+ >>> mmcv.imshow(img)
+ >>> v[5] # get the 6th frame
+ """
+
+ def __init__(self, filename, cache_capacity=10):
+ # Check whether the video path is a url
+ if not filename.startswith(('https://', 'http://')):
+ check_file_exist(filename, 'Video file not found: ' + filename)
+ self._vcap = cv2.VideoCapture(filename)
+ assert cache_capacity > 0
+ self._cache = Cache(cache_capacity)
+ self._position = 0
+ # get basic info
+ self._width = int(self._vcap.get(CAP_PROP_FRAME_WIDTH))
+ self._height = int(self._vcap.get(CAP_PROP_FRAME_HEIGHT))
+ self._fps = self._vcap.get(CAP_PROP_FPS)
+ self._frame_cnt = int(self._vcap.get(CAP_PROP_FRAME_COUNT))
+ self._fourcc = self._vcap.get(CAP_PROP_FOURCC)
+
+ @property
+ def vcap(self):
+ """:obj:`cv2.VideoCapture`: The raw VideoCapture object."""
+ return self._vcap
+
+ @property
+ def opened(self):
+ """bool: Indicate whether the video is opened."""
+ return self._vcap.isOpened()
+
+ @property
+ def width(self):
+ """int: Width of video frames."""
+ return self._width
+
+ @property
+ def height(self):
+ """int: Height of video frames."""
+ return self._height
+
+ @property
+ def resolution(self):
+ """tuple: Video resolution (width, height)."""
+ return (self._width, self._height)
+
+ @property
+ def fps(self):
+ """float: FPS of the video."""
+ return self._fps
+
+ @property
+ def frame_cnt(self):
+ """int: Total frames of the video."""
+ return self._frame_cnt
+
+ @property
+ def fourcc(self):
+ """str: "Four character code" of the video."""
+ return self._fourcc
+
+ @property
+ def position(self):
+ """int: Current cursor position, indicating frame decoded."""
+ return self._position
+
+ def _get_real_position(self):
+ return int(round(self._vcap.get(CAP_PROP_POS_FRAMES)))
+
+ def _set_real_position(self, frame_id):
+ self._vcap.set(CAP_PROP_POS_FRAMES, frame_id)
+ pos = self._get_real_position()
+ for _ in range(frame_id - pos):
+ self._vcap.read()
+ self._position = frame_id
+
+ def read(self):
+ """Read the next frame.
+
+ If the next frame have been decoded before and in the cache, then
+ return it directly, otherwise decode, cache and return it.
+
+ Returns:
+ ndarray or None: Return the frame if successful, otherwise None.
+ """
+ # pos = self._position
+ if self._cache:
+ img = self._cache.get(self._position)
+ if img is not None:
+ ret = True
+ else:
+ if self._position != self._get_real_position():
+ self._set_real_position(self._position)
+ ret, img = self._vcap.read()
+ if ret:
+ self._cache.put(self._position, img)
+ else:
+ ret, img = self._vcap.read()
+ if ret:
+ self._position += 1
+ return img
+
+ def get_frame(self, frame_id):
+ """Get frame by index.
+
+ Args:
+ frame_id (int): Index of the expected frame, 0-based.
+
+ Returns:
+ ndarray or None: Return the frame if successful, otherwise None.
+ """
+ if frame_id < 0 or frame_id >= self._frame_cnt:
+ raise IndexError(
+ f'"frame_id" must be between 0 and {self._frame_cnt - 1}')
+ if frame_id == self._position:
+ return self.read()
+ if self._cache:
+ img = self._cache.get(frame_id)
+ if img is not None:
+ self._position = frame_id + 1
+ return img
+ self._set_real_position(frame_id)
+ ret, img = self._vcap.read()
+ if ret:
+ if self._cache:
+ self._cache.put(self._position, img)
+ self._position += 1
+ return img
+
+ def current_frame(self):
+ """Get the current frame (frame that is just visited).
+
+ Returns:
+ ndarray or None: If the video is fresh, return None, otherwise
+ return the frame.
+ """
+ if self._position == 0:
+ return None
+ return self._cache.get(self._position - 1)
+
+ def cvt2frames(self,
+ frame_dir,
+ file_start=0,
+ filename_tmpl='{:06d}.jpg',
+ start=0,
+ max_num=0,
+ show_progress=True):
+ """Convert a video to frame images.
+
+ Args:
+ frame_dir (str): Output directory to store all the frame images.
+ file_start (int): Filenames will start from the specified number.
+ filename_tmpl (str): Filename template with the index as the
+ placeholder.
+ start (int): The starting frame index.
+ max_num (int): Maximum number of frames to be written.
+ show_progress (bool): Whether to show a progress bar.
+ """
+ mkdir_or_exist(frame_dir)
+ if max_num == 0:
+ task_num = self.frame_cnt - start
+ else:
+ task_num = min(self.frame_cnt - start, max_num)
+ if task_num <= 0:
+ raise ValueError('start must be less than total frame number')
+ if start > 0:
+ self._set_real_position(start)
+
+ def write_frame(file_idx):
+ img = self.read()
+ if img is None:
+ return
+ filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
+ cv2.imwrite(filename, img)
+
+ if show_progress:
+ track_progress(write_frame, range(file_start,
+ file_start + task_num))
+ else:
+ for i in range(task_num):
+ write_frame(file_start + i)
+
+ def __len__(self):
+ return self.frame_cnt
+
+ def __getitem__(self, index):
+ if isinstance(index, slice):
+ return [
+ self.get_frame(i)
+ for i in range(*index.indices(self.frame_cnt))
+ ]
+ # support negative indexing
+ if index < 0:
+ index += self.frame_cnt
+ if index < 0:
+ raise IndexError('index out of range')
+ return self.get_frame(index)
+
+ def __iter__(self):
+ self._set_real_position(0)
+ return self
+
+ def __next__(self):
+ img = self.read()
+ if img is not None:
+ return img
+ else:
+ raise StopIteration
+
+ next = __next__
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self._vcap.release()
+
+
+def frames2video(frame_dir,
+ video_file,
+ fps=30,
+ fourcc='XVID',
+ filename_tmpl='{:06d}.jpg',
+ start=0,
+ end=0,
+ show_progress=True):
+ """Read the frame images from a directory and join them as a video.
+
+ Args:
+ frame_dir (str): The directory containing video frames.
+ video_file (str): Output filename.
+ fps (float): FPS of the output video.
+ fourcc (str): Fourcc of the output video, this should be compatible
+ with the output file type.
+ filename_tmpl (str): Filename template with the index as the variable.
+ start (int): Starting frame index.
+ end (int): Ending frame index.
+ show_progress (bool): Whether to show a progress bar.
+ """
+ if end == 0:
+ ext = filename_tmpl.split('.')[-1]
+ end = len([name for name in scandir(frame_dir, ext)])
+ first_file = osp.join(frame_dir, filename_tmpl.format(start))
+ check_file_exist(first_file, 'The start frame not found: ' + first_file)
+ img = cv2.imread(first_file)
+ height, width = img.shape[:2]
+ resolution = (width, height)
+ vwriter = cv2.VideoWriter(video_file, VideoWriter_fourcc(*fourcc), fps,
+ resolution)
+
+ def write_frame(file_idx):
+ filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
+ img = cv2.imread(filename)
+ vwriter.write(img)
+
+ if show_progress:
+ track_progress(write_frame, range(start, end))
+ else:
+ for i in range(start, end):
+ write_frame(i)
+ vwriter.release()
diff --git a/src/custom_mmpkg/custom_mmcv/video/optflow.py b/src/custom_mmpkg/custom_mmcv/video/optflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..71c7cc1c48a896191e36d159680df29ac1d70dc4
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/video/optflow.py
@@ -0,0 +1,254 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import cv2
+import numpy as np
+
+from custom_mmpkg.custom_mmcv.arraymisc import dequantize, quantize
+from custom_mmpkg.custom_mmcv.image import imread, imwrite
+from custom_mmpkg.custom_mmcv.utils import is_str
+
+
+def flowread(flow_or_path, quantize=False, concat_axis=0, *args, **kwargs):
+ """Read an optical flow map.
+
+ Args:
+ flow_or_path (ndarray or str): A flow map or filepath.
+ quantize (bool): whether to read quantized pair, if set to True,
+ remaining args will be passed to :func:`dequantize_flow`.
+ concat_axis (int): The axis that dx and dy are concatenated,
+ can be either 0 or 1. Ignored if quantize is False.
+
+ Returns:
+ ndarray: Optical flow represented as a (h, w, 2) numpy array
+ """
+ if isinstance(flow_or_path, np.ndarray):
+ if (flow_or_path.ndim != 3) or (flow_or_path.shape[-1] != 2):
+ raise ValueError(f'Invalid flow with shape {flow_or_path.shape}')
+ return flow_or_path
+ elif not is_str(flow_or_path):
+ raise TypeError(f'"flow_or_path" must be a filename or numpy array, '
+ f'not {type(flow_or_path)}')
+
+ if not quantize:
+ with open(flow_or_path, 'rb') as f:
+ try:
+ header = f.read(4).decode('utf-8')
+ except Exception:
+ raise IOError(f'Invalid flow file: {flow_or_path}')
+ else:
+ if header != 'PIEH':
+ raise IOError(f'Invalid flow file: {flow_or_path}, '
+ 'header does not contain PIEH')
+
+ w = np.fromfile(f, np.int32, 1).squeeze()
+ h = np.fromfile(f, np.int32, 1).squeeze()
+ flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
+ else:
+ assert concat_axis in [0, 1]
+ cat_flow = imread(flow_or_path, flag='unchanged')
+ if cat_flow.ndim != 2:
+ raise IOError(
+ f'{flow_or_path} is not a valid quantized flow file, '
+ f'its dimension is {cat_flow.ndim}.')
+ assert cat_flow.shape[concat_axis] % 2 == 0
+ dx, dy = np.split(cat_flow, 2, axis=concat_axis)
+ flow = dequantize_flow(dx, dy, *args, **kwargs)
+
+ return flow.astype(np.float32)
+
+
+def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
+ """Write optical flow to file.
+
+ If the flow is not quantized, it will be saved as a .flo file losslessly,
+ otherwise a jpeg image which is lossy but of much smaller size. (dx and dy
+ will be concatenated horizontally into a single image if quantize is True.)
+
+ Args:
+ flow (ndarray): (h, w, 2) array of optical flow.
+ filename (str): Output filepath.
+ quantize (bool): Whether to quantize the flow and save it to 2 jpeg
+ images. If set to True, remaining args will be passed to
+ :func:`quantize_flow`.
+ concat_axis (int): The axis that dx and dy are concatenated,
+ can be either 0 or 1. Ignored if quantize is False.
+ """
+ if not quantize:
+ with open(filename, 'wb') as f:
+ f.write('PIEH'.encode('utf-8'))
+ np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
+ flow = flow.astype(np.float32)
+ flow.tofile(f)
+ f.flush()
+ else:
+ assert concat_axis in [0, 1]
+ dx, dy = quantize_flow(flow, *args, **kwargs)
+ dxdy = np.concatenate((dx, dy), axis=concat_axis)
+ imwrite(dxdy, filename)
+
+
+def quantize_flow(flow, max_val=0.02, norm=True):
+ """Quantize flow to [0, 255].
+
+ After this step, the size of flow will be much smaller, and can be
+ dumped as jpeg images.
+
+ Args:
+ flow (ndarray): (h, w, 2) array of optical flow.
+ max_val (float): Maximum value of flow, values beyond
+ [-max_val, max_val] will be truncated.
+ norm (bool): Whether to divide flow values by image width/height.
+
+ Returns:
+ tuple[ndarray]: Quantized dx and dy.
+ """
+ h, w, _ = flow.shape
+ dx = flow[..., 0]
+ dy = flow[..., 1]
+ if norm:
+ dx = dx / w # avoid inplace operations
+ dy = dy / h
+ # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
+ flow_comps = [
+ quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]
+ ]
+ return tuple(flow_comps)
+
+
+def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
+ """Recover from quantized flow.
+
+ Args:
+ dx (ndarray): Quantized dx.
+ dy (ndarray): Quantized dy.
+ max_val (float): Maximum value used when quantizing.
+ denorm (bool): Whether to multiply flow values with width/height.
+
+ Returns:
+ ndarray: Dequantized flow.
+ """
+ assert dx.shape == dy.shape
+ assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
+
+ dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
+
+ if denorm:
+ dx *= dx.shape[1]
+ dy *= dx.shape[0]
+ flow = np.dstack((dx, dy))
+ return flow
+
+
+def flow_warp(img, flow, filling_value=0, interpolate_mode='nearest'):
+ """Use flow to warp img.
+
+ Args:
+ img (ndarray, float or uint8): Image to be warped.
+ flow (ndarray, float): Optical Flow.
+ filling_value (int): The missing pixels will be set with filling_value.
+ interpolate_mode (str): bilinear -> Bilinear Interpolation;
+ nearest -> Nearest Neighbor.
+
+ Returns:
+ ndarray: Warped image with the same shape of img
+ """
+ warnings.warn('This function is just for prototyping and cannot '
+ 'guarantee the computational efficiency.')
+ assert flow.ndim == 3, 'Flow must be in 3D arrays.'
+ height = flow.shape[0]
+ width = flow.shape[1]
+ channels = img.shape[2]
+
+ output = np.ones(
+ (height, width, channels), dtype=img.dtype) * filling_value
+
+ grid = np.indices((height, width)).swapaxes(0, 1).swapaxes(1, 2)
+ dx = grid[:, :, 0] + flow[:, :, 1]
+ dy = grid[:, :, 1] + flow[:, :, 0]
+ sx = np.floor(dx).astype(int)
+ sy = np.floor(dy).astype(int)
+ valid = (sx >= 0) & (sx < height - 1) & (sy >= 0) & (sy < width - 1)
+
+ if interpolate_mode == 'nearest':
+ output[valid, :] = img[dx[valid].round().astype(int),
+ dy[valid].round().astype(int), :]
+ elif interpolate_mode == 'bilinear':
+ # dirty walkround for integer positions
+ eps_ = 1e-6
+ dx, dy = dx + eps_, dy + eps_
+ left_top_ = img[np.floor(dx[valid]).astype(int),
+ np.floor(dy[valid]).astype(int), :] * (
+ np.ceil(dx[valid]) - dx[valid])[:, None] * (
+ np.ceil(dy[valid]) - dy[valid])[:, None]
+ left_down_ = img[np.ceil(dx[valid]).astype(int),
+ np.floor(dy[valid]).astype(int), :] * (
+ dx[valid] - np.floor(dx[valid]))[:, None] * (
+ np.ceil(dy[valid]) - dy[valid])[:, None]
+ right_top_ = img[np.floor(dx[valid]).astype(int),
+ np.ceil(dy[valid]).astype(int), :] * (
+ np.ceil(dx[valid]) - dx[valid])[:, None] * (
+ dy[valid] - np.floor(dy[valid]))[:, None]
+ right_down_ = img[np.ceil(dx[valid]).astype(int),
+ np.ceil(dy[valid]).astype(int), :] * (
+ dx[valid] - np.floor(dx[valid]))[:, None] * (
+ dy[valid] - np.floor(dy[valid]))[:, None]
+ output[valid, :] = left_top_ + left_down_ + right_top_ + right_down_
+ else:
+ raise NotImplementedError(
+ 'We only support interpolation modes of nearest and bilinear, '
+ f'but got {interpolate_mode}.')
+ return output.astype(img.dtype)
+
+
+def flow_from_bytes(content):
+ """Read dense optical flow from bytes.
+
+ .. note::
+ This load optical flow function works for FlyingChairs, FlyingThings3D,
+ Sintel, FlyingChairsOcc datasets, but cannot load the data from
+ ChairsSDHom.
+
+ Args:
+ content (bytes): Optical flow bytes got from files or other streams.
+
+ Returns:
+ ndarray: Loaded optical flow with the shape (H, W, 2).
+ """
+
+ # header in first 4 bytes
+ header = content[:4]
+ if header.decode('utf-8') != 'PIEH':
+ raise Exception('Flow file header does not contain PIEH')
+ # width in second 4 bytes
+ width = np.frombuffer(content[4:], np.int32, 1).squeeze()
+ # height in third 4 bytes
+ height = np.frombuffer(content[8:], np.int32, 1).squeeze()
+ # after first 12 bytes, all bytes are flow
+ flow = np.frombuffer(content[12:], np.float32, width * height * 2).reshape(
+ (height, width, 2))
+
+ return flow
+
+
+def sparse_flow_from_bytes(content):
+ """Read the optical flow in KITTI datasets from bytes.
+
+ This function is modified from RAFT load the `KITTI datasets
+ `_.
+
+ Args:
+ content (bytes): Optical flow bytes got from files or other streams.
+
+ Returns:
+ Tuple(ndarray, ndarray): Loaded optical flow with the shape (H, W, 2)
+ and flow valid mask with the shape (H, W).
+ """ # nopa
+
+ content = np.frombuffer(content, np.uint8)
+ flow = cv2.imdecode(content, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
+ flow = flow[:, :, ::-1].astype(np.float32)
+ # flow shape (H, W, 2) valid shape (H, W)
+ flow, valid = flow[:, :, :2], flow[:, :, 2]
+ flow = (flow - 2**15) / 64.0
+ return flow, valid
diff --git a/src/custom_mmpkg/custom_mmcv/video/processing.py b/src/custom_mmpkg/custom_mmcv/video/processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..72865d9041f5d8a9717b41b02beca67fa622fd9a
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/video/processing.py
@@ -0,0 +1,160 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import os.path as osp
+import subprocess
+import tempfile
+
+from custom_mmpkg.custom_mmcv.utils import requires_executable
+
+
+@requires_executable('ffmpeg')
+def convert_video(in_file,
+ out_file,
+ print_cmd=False,
+ pre_options='',
+ **kwargs):
+ """Convert a video with ffmpeg.
+
+ This provides a general api to ffmpeg, the executed command is::
+
+ `ffmpeg -y -i `
+
+ Options(kwargs) are mapped to ffmpeg commands with the following rules:
+
+ - key=val: "-key val"
+ - key=True: "-key"
+ - key=False: ""
+
+ Args:
+ in_file (str): Input video filename.
+ out_file (str): Output video filename.
+ pre_options (str): Options appears before "-i ".
+ print_cmd (bool): Whether to print the final ffmpeg command.
+ """
+ options = []
+ for k, v in kwargs.items():
+ if isinstance(v, bool):
+ if v:
+ options.append(f'-{k}')
+ elif k == 'log_level':
+ assert v in [
+ 'quiet', 'panic', 'fatal', 'error', 'warning', 'info',
+ 'verbose', 'debug', 'trace'
+ ]
+ options.append(f'-loglevel {v}')
+ else:
+ options.append(f'-{k} {v}')
+ cmd = f'ffmpeg -y {pre_options} -i {in_file} {" ".join(options)} ' \
+ f'{out_file}'
+ if print_cmd:
+ print(cmd)
+ subprocess.call(cmd, shell=True)
+
+
+@requires_executable('ffmpeg')
+def resize_video(in_file,
+ out_file,
+ size=None,
+ ratio=None,
+ keep_ar=False,
+ log_level='info',
+ print_cmd=False):
+ """Resize a video.
+
+ Args:
+ in_file (str): Input video filename.
+ out_file (str): Output video filename.
+ size (tuple): Expected size (w, h), eg, (320, 240) or (320, -1).
+ ratio (tuple or float): Expected resize ratio, (2, 0.5) means
+ (w*2, h*0.5).
+ keep_ar (bool): Whether to keep original aspect ratio.
+ log_level (str): Logging level of ffmpeg.
+ print_cmd (bool): Whether to print the final ffmpeg command.
+ """
+ if size is None and ratio is None:
+ raise ValueError('expected size or ratio must be specified')
+ if size is not None and ratio is not None:
+ raise ValueError('size and ratio cannot be specified at the same time')
+ options = {'log_level': log_level}
+ if size:
+ if not keep_ar:
+ options['vf'] = f'scale={size[0]}:{size[1]}'
+ else:
+ options['vf'] = f'scale=w={size[0]}:h={size[1]}:' \
+ 'force_original_aspect_ratio=decrease'
+ else:
+ if not isinstance(ratio, tuple):
+ ratio = (ratio, ratio)
+ options['vf'] = f'scale="trunc(iw*{ratio[0]}):trunc(ih*{ratio[1]})"'
+ convert_video(in_file, out_file, print_cmd, **options)
+
+
+@requires_executable('ffmpeg')
+def cut_video(in_file,
+ out_file,
+ start=None,
+ end=None,
+ vcodec=None,
+ acodec=None,
+ log_level='info',
+ print_cmd=False):
+ """Cut a clip from a video.
+
+ Args:
+ in_file (str): Input video filename.
+ out_file (str): Output video filename.
+ start (None or float): Start time (in seconds).
+ end (None or float): End time (in seconds).
+ vcodec (None or str): Output video codec, None for unchanged.
+ acodec (None or str): Output audio codec, None for unchanged.
+ log_level (str): Logging level of ffmpeg.
+ print_cmd (bool): Whether to print the final ffmpeg command.
+ """
+ options = {'log_level': log_level}
+ if vcodec is None:
+ options['vcodec'] = 'copy'
+ if acodec is None:
+ options['acodec'] = 'copy'
+ if start:
+ options['ss'] = start
+ else:
+ start = 0
+ if end:
+ options['t'] = end - start
+ convert_video(in_file, out_file, print_cmd, **options)
+
+
+@requires_executable('ffmpeg')
+def concat_video(video_list,
+ out_file,
+ vcodec=None,
+ acodec=None,
+ log_level='info',
+ print_cmd=False):
+ """Concatenate multiple videos into a single one.
+
+ Args:
+ video_list (list): A list of video filenames
+ out_file (str): Output video filename
+ vcodec (None or str): Output video codec, None for unchanged
+ acodec (None or str): Output audio codec, None for unchanged
+ log_level (str): Logging level of ffmpeg.
+ print_cmd (bool): Whether to print the final ffmpeg command.
+ """
+ tmp_filehandler, tmp_filename = tempfile.mkstemp(suffix='.txt', text=True)
+ with open(tmp_filename, 'w') as f:
+ for filename in video_list:
+ f.write(f'file {osp.abspath(filename)}\n')
+ options = {'log_level': log_level}
+ if vcodec is None:
+ options['vcodec'] = 'copy'
+ if acodec is None:
+ options['acodec'] = 'copy'
+ convert_video(
+ tmp_filename,
+ out_file,
+ print_cmd,
+ pre_options='-f concat -safe 0',
+ **options)
+ os.close(tmp_filehandler)
+ os.remove(tmp_filename)
diff --git a/src/custom_mmpkg/custom_mmcv/visualization/__init__.py b/src/custom_mmpkg/custom_mmcv/visualization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..835df136bdcf69348281d22914d41aa84cdf92b1
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/visualization/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .color import Color, color_val
+from .image import imshow, imshow_bboxes, imshow_det_bboxes
+from .optflow import flow2rgb, flowshow, make_color_wheel
+
+__all__ = [
+ 'Color', 'color_val', 'imshow', 'imshow_bboxes', 'imshow_det_bboxes',
+ 'flowshow', 'flow2rgb', 'make_color_wheel'
+]
diff --git a/src/custom_mmpkg/custom_mmcv/visualization/color.py b/src/custom_mmpkg/custom_mmcv/visualization/color.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bff8a9dc94fc5ff8dbd5425faeea165332ac10a
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/visualization/color.py
@@ -0,0 +1,51 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from enum import Enum
+
+import numpy as np
+
+from custom_mmpkg.custom_mmcv.utils import is_str
+
+
+class Color(Enum):
+ """An enum that defines common colors.
+
+ Contains red, green, blue, cyan, yellow, magenta, white and black.
+ """
+ red = (0, 0, 255)
+ green = (0, 255, 0)
+ blue = (255, 0, 0)
+ cyan = (255, 255, 0)
+ yellow = (0, 255, 255)
+ magenta = (255, 0, 255)
+ white = (255, 255, 255)
+ black = (0, 0, 0)
+
+
+def color_val(color):
+ """Convert various input to color tuples.
+
+ Args:
+ color (:obj:`Color`/str/tuple/int/ndarray): Color inputs
+
+ Returns:
+ tuple[int]: A tuple of 3 integers indicating BGR channels.
+ """
+ if is_str(color):
+ return Color[color].value
+ elif isinstance(color, Color):
+ return color.value
+ elif isinstance(color, tuple):
+ assert len(color) == 3
+ for channel in color:
+ assert 0 <= channel <= 255
+ return color
+ elif isinstance(color, int):
+ assert 0 <= color <= 255
+ return color, color, color
+ elif isinstance(color, np.ndarray):
+ assert color.ndim == 1 and color.size == 3
+ assert np.all((color >= 0) & (color <= 255))
+ color = color.astype(np.uint8)
+ return tuple(color)
+ else:
+ raise TypeError(f'Invalid type for color: {type(color)}')
diff --git a/src/custom_mmpkg/custom_mmcv/visualization/image.py b/src/custom_mmpkg/custom_mmcv/visualization/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f77c6d1033dd2a5968cedf3a5fe77d91cd948b8
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/visualization/image.py
@@ -0,0 +1,152 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import cv2
+import numpy as np
+
+from custom_mmpkg.custom_mmcv.image import imread, imwrite
+from .color import color_val
+
+
+def imshow(img, win_name='', wait_time=0):
+ """Show an image.
+
+ Args:
+ img (str or ndarray): The image to be displayed.
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ """
+ cv2.imshow(win_name, imread(img))
+ if wait_time == 0: # prevent from hanging if windows was closed
+ while True:
+ ret = cv2.waitKey(1)
+
+ closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1
+ # if user closed window or if some key pressed
+ if closed or ret != -1:
+ break
+ else:
+ ret = cv2.waitKey(wait_time)
+
+
+def imshow_bboxes(img,
+ bboxes,
+ colors='green',
+ top_k=-1,
+ thickness=1,
+ show=True,
+ win_name='',
+ wait_time=0,
+ out_file=None):
+ """Draw bboxes on an image.
+
+ Args:
+ img (str or ndarray): The image to be displayed.
+ bboxes (list or ndarray): A list of ndarray of shape (k, 4).
+ colors (list[str or tuple or Color]): A list of colors.
+ top_k (int): Plot the first k bboxes only if set positive.
+ thickness (int): Thickness of lines.
+ show (bool): Whether to show the image.
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ out_file (str, optional): The filename to write the image.
+
+ Returns:
+ ndarray: The image with bboxes drawn on it.
+ """
+ img = imread(img)
+ img = np.ascontiguousarray(img)
+
+ if isinstance(bboxes, np.ndarray):
+ bboxes = [bboxes]
+ if not isinstance(colors, list):
+ colors = [colors for _ in range(len(bboxes))]
+ colors = [color_val(c) for c in colors]
+ assert len(bboxes) == len(colors)
+
+ for i, _bboxes in enumerate(bboxes):
+ _bboxes = _bboxes.astype(np.int32)
+ if top_k <= 0:
+ _top_k = _bboxes.shape[0]
+ else:
+ _top_k = min(top_k, _bboxes.shape[0])
+ for j in range(_top_k):
+ left_top = (_bboxes[j, 0], _bboxes[j, 1])
+ right_bottom = (_bboxes[j, 2], _bboxes[j, 3])
+ cv2.rectangle(
+ img, left_top, right_bottom, colors[i], thickness=thickness)
+
+ if show:
+ imshow(img, win_name, wait_time)
+ if out_file is not None:
+ imwrite(img, out_file)
+ return img
+
+
+def imshow_det_bboxes(img,
+ bboxes,
+ labels,
+ class_names=None,
+ score_thr=0,
+ bbox_color='green',
+ text_color='green',
+ thickness=1,
+ font_scale=0.5,
+ show=True,
+ win_name='',
+ wait_time=0,
+ out_file=None):
+ """Draw bboxes and class labels (with scores) on an image.
+
+ Args:
+ img (str or ndarray): The image to be displayed.
+ bboxes (ndarray): Bounding boxes (with scores), shaped (n, 4) or
+ (n, 5).
+ labels (ndarray): Labels of bboxes.
+ class_names (list[str]): Names of each classes.
+ score_thr (float): Minimum score of bboxes to be shown.
+ bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
+ text_color (str or tuple or :obj:`Color`): Color of texts.
+ thickness (int): Thickness of lines.
+ font_scale (float): Font scales of texts.
+ show (bool): Whether to show the image.
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ out_file (str or None): The filename to write the image.
+
+ Returns:
+ ndarray: The image with bboxes drawn on it.
+ """
+ assert bboxes.ndim == 2
+ assert labels.ndim == 1
+ assert bboxes.shape[0] == labels.shape[0]
+ assert bboxes.shape[1] == 4 or bboxes.shape[1] == 5
+ img = imread(img)
+ img = np.ascontiguousarray(img)
+
+ if score_thr > 0:
+ assert bboxes.shape[1] == 5
+ scores = bboxes[:, -1]
+ inds = scores > score_thr
+ bboxes = bboxes[inds, :]
+ labels = labels[inds]
+
+ bbox_color = color_val(bbox_color)
+ text_color = color_val(text_color)
+
+ for bbox, label in zip(bboxes, labels):
+ bbox_int = bbox.astype(np.int32)
+ left_top = (bbox_int[0], bbox_int[1])
+ right_bottom = (bbox_int[2], bbox_int[3])
+ cv2.rectangle(
+ img, left_top, right_bottom, bbox_color, thickness=thickness)
+ label_text = class_names[
+ label] if class_names is not None else f'cls {label}'
+ if len(bbox) > 4:
+ label_text += f'|{bbox[-1]:.02f}'
+ cv2.putText(img, label_text, (bbox_int[0], bbox_int[1] - 2),
+ cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color)
+
+ if show:
+ imshow(img, win_name, wait_time)
+ if out_file is not None:
+ imwrite(img, out_file)
+ return img
diff --git a/src/custom_mmpkg/custom_mmcv/visualization/optflow.py b/src/custom_mmpkg/custom_mmcv/visualization/optflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b13b411f7161205eba2653c357a84f8916a353a
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmcv/visualization/optflow.py
@@ -0,0 +1,112 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from __future__ import division
+
+import numpy as np
+
+from custom_mmpkg.custom_mmcv.image import rgb2bgr
+from custom_mmpkg.custom_mmcv.video import flowread
+from .image import imshow
+
+
+def flowshow(flow, win_name='', wait_time=0):
+ """Show optical flow.
+
+ Args:
+ flow (ndarray or str): The optical flow to be displayed.
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ """
+ flow = flowread(flow)
+ flow_img = flow2rgb(flow)
+ imshow(rgb2bgr(flow_img), win_name, wait_time)
+
+
+def flow2rgb(flow, color_wheel=None, unknown_thr=1e6):
+ """Convert flow map to RGB image.
+
+ Args:
+ flow (ndarray): Array of optical flow.
+ color_wheel (ndarray or None): Color wheel used to map flow field to
+ RGB colorspace. Default color wheel will be used if not specified.
+ unknown_thr (str): Values above this threshold will be marked as
+ unknown and thus ignored.
+
+ Returns:
+ ndarray: RGB image that can be visualized.
+ """
+ assert flow.ndim == 3 and flow.shape[-1] == 2
+ if color_wheel is None:
+ color_wheel = make_color_wheel()
+ assert color_wheel.ndim == 2 and color_wheel.shape[1] == 3
+ num_bins = color_wheel.shape[0]
+
+ dx = flow[:, :, 0].copy()
+ dy = flow[:, :, 1].copy()
+
+ ignore_inds = (
+ np.isnan(dx) | np.isnan(dy) | (np.abs(dx) > unknown_thr) |
+ (np.abs(dy) > unknown_thr))
+ dx[ignore_inds] = 0
+ dy[ignore_inds] = 0
+
+ rad = np.sqrt(dx**2 + dy**2)
+ if np.any(rad > np.finfo(float).eps):
+ max_rad = np.max(rad)
+ dx /= max_rad
+ dy /= max_rad
+
+ rad = np.sqrt(dx**2 + dy**2)
+ angle = np.arctan2(-dy, -dx) / np.pi
+
+ bin_real = (angle + 1) / 2 * (num_bins - 1)
+ bin_left = np.floor(bin_real).astype(int)
+ bin_right = (bin_left + 1) % num_bins
+ w = (bin_real - bin_left.astype(np.float32))[..., None]
+ flow_img = (1 -
+ w) * color_wheel[bin_left, :] + w * color_wheel[bin_right, :]
+ small_ind = rad <= 1
+ flow_img[small_ind] = 1 - rad[small_ind, None] * (1 - flow_img[small_ind])
+ flow_img[np.logical_not(small_ind)] *= 0.75
+
+ flow_img[ignore_inds, :] = 0
+
+ return flow_img
+
+
+def make_color_wheel(bins=None):
+ """Build a color wheel.
+
+ Args:
+ bins(list or tuple, optional): Specify the number of bins for each
+ color range, corresponding to six ranges: red -> yellow,
+ yellow -> green, green -> cyan, cyan -> blue, blue -> magenta,
+ magenta -> red. [15, 6, 4, 11, 13, 6] is used for default
+ (see Middlebury).
+
+ Returns:
+ ndarray: Color wheel of shape (total_bins, 3).
+ """
+ if bins is None:
+ bins = [15, 6, 4, 11, 13, 6]
+ assert len(bins) == 6
+
+ RY, YG, GC, CB, BM, MR = tuple(bins)
+
+ ry = [1, np.arange(RY) / RY, 0]
+ yg = [1 - np.arange(YG) / YG, 1, 0]
+ gc = [0, 1, np.arange(GC) / GC]
+ cb = [0, 1 - np.arange(CB) / CB, 1]
+ bm = [np.arange(BM) / BM, 0, 1]
+ mr = [1, 0, 1 - np.arange(MR) / MR]
+
+ num_bins = RY + YG + GC + CB + BM + MR
+
+ color_wheel = np.zeros((3, num_bins), dtype=np.float32)
+
+ col = 0
+ for i, color in enumerate([ry, yg, gc, cb, bm, mr]):
+ for j in range(3):
+ color_wheel[j, col:col + bins[i]] = color[j]
+ col += bins[i]
+
+ return color_wheel.T
diff --git a/src/custom_mmpkg/custom_mmseg/apis/__init__.py b/src/custom_mmpkg/custom_mmseg/apis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..170724be38de42daf2bc1a1910e181d68818f165
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/apis/__init__.py
@@ -0,0 +1,9 @@
+from .inference import inference_segmentor, init_segmentor, show_result_pyplot
+from .test import multi_gpu_test, single_gpu_test
+from .train import get_root_logger, set_random_seed, train_segmentor
+
+__all__ = [
+ 'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor',
+ 'inference_segmentor', 'multi_gpu_test', 'single_gpu_test',
+ 'show_result_pyplot'
+]
diff --git a/src/custom_mmpkg/custom_mmseg/apis/inference.py b/src/custom_mmpkg/custom_mmseg/apis/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ee57d61e59f67926c7be6a139d057805026b816
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/apis/inference.py
@@ -0,0 +1,137 @@
+import matplotlib.pyplot as plt
+import custom_mmpkg.custom_mmcv as mmcv
+import torch
+from custom_mmpkg.custom_mmcv.parallel import collate, scatter
+from custom_mmpkg.custom_mmcv.runner import load_checkpoint
+
+from custom_mmpkg.custom_mmseg.datasets.pipelines import Compose
+from custom_mmpkg.custom_mmseg.models import build_segmentor
+
+
+def init_segmentor(config, checkpoint=None, device="cpu"):
+ """Initialize a segmentor from config file.
+
+ Args:
+ config (str or :obj:`mmcv.Config`): Config file path or the config
+ object.
+ checkpoint (str, optional): Checkpoint path. If left as None, the model
+ will not load any weights.
+ device (str, optional) CPU/CUDA device option. Default 'cuda:0'.
+ Use 'cpu' for loading model on CPU.
+ Returns:
+ nn.Module: The constructed segmentor.
+ """
+ if isinstance(config, str):
+ config = mmcv.Config.fromfile(config)
+ elif not isinstance(config, mmcv.Config):
+ raise TypeError('config must be a filename or Config object, '
+ 'but got {}'.format(type(config)))
+ config.model.pretrained = None
+ config.model.train_cfg = None
+ model = build_segmentor(config.model, test_cfg=config.get('test_cfg'))
+ if checkpoint is not None:
+ checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
+ model.CLASSES = checkpoint['meta']['CLASSES']
+ model.PALETTE = checkpoint['meta']['PALETTE']
+ model.cfg = config # save the config in the model for convenience
+ model.to(device)
+ model.eval()
+ return model
+
+
+class LoadImage:
+ """A simple pipeline to load image."""
+
+ def __call__(self, results):
+ """Call function to load images into results.
+
+ Args:
+ results (dict): A result dict contains the file name
+ of the image to be read.
+
+ Returns:
+ dict: ``results`` will be returned containing loaded image.
+ """
+
+ if isinstance(results['img'], str):
+ results['filename'] = results['img']
+ results['ori_filename'] = results['img']
+ else:
+ results['filename'] = None
+ results['ori_filename'] = None
+ img = mmcv.imread(results['img'])
+ results['img'] = img
+ results['img_shape'] = img.shape
+ results['ori_shape'] = img.shape
+ return results
+
+
+def inference_segmentor(model, img):
+ """Inference image(s) with the segmentor.
+
+ Args:
+ model (nn.Module): The loaded segmentor.
+ imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
+ images.
+
+ Returns:
+ (list[Tensor]): The segmentation result.
+ """
+ cfg = model.cfg
+ device = next(model.parameters()).device # model device
+ # build the data pipeline
+ test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
+ test_pipeline = Compose(test_pipeline)
+ # prepare data
+ data = dict(img=img)
+ data = test_pipeline(data)
+ data = collate([data], samples_per_gpu=1)
+ if next(model.parameters()).is_cuda:
+ # scatter to specified GPU
+ data = scatter(data, [device])[0]
+ else:
+ data['img'][0] = data['img'][0].to(device)
+ data['img_metas'] = [i.data[0] for i in data['img_metas']]
+
+ # forward the model
+ with torch.no_grad():
+ result = model(return_loss=False, rescale=True, **data)
+ return result
+
+
+def show_result_pyplot(model,
+ img,
+ result,
+ palette=None,
+ fig_size=(15, 10),
+ opacity=0.5,
+ title='',
+ block=True):
+ """Visualize the segmentation results on the image.
+
+ Args:
+ model (nn.Module): The loaded segmentor.
+ img (str or np.ndarray): Image filename or loaded image.
+ result (list): The segmentation result.
+ palette (list[list[int]]] | None): The palette of segmentation
+ map. If None is given, random palette will be generated.
+ Default: None
+ fig_size (tuple): Figure size of the pyplot figure.
+ opacity(float): Opacity of painted segmentation map.
+ Default 0.5.
+ Must be in (0, 1] range.
+ title (str): The title of pyplot figure.
+ Default is ''.
+ block (bool): Whether to block the pyplot figure.
+ Default is True.
+ """
+ if hasattr(model, 'module'):
+ model = model.module
+ img = model.show_result(
+ img, result, palette=palette, show=False, opacity=opacity)
+ # plt.figure(figsize=fig_size)
+ # plt.imshow(mmcv.bgr2rgb(img))
+ # plt.title(title)
+ # plt.tight_layout()
+ # plt.show(block=block)
+ return mmcv.bgr2rgb(img)
diff --git a/src/custom_mmpkg/custom_mmseg/apis/test.py b/src/custom_mmpkg/custom_mmseg/apis/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d0078b5b52eca53ddb0c4bb28adb7b1afe59728
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/apis/test.py
@@ -0,0 +1,238 @@
+import os.path as osp
+import pickle
+import shutil
+import tempfile
+
+import custom_mmpkg.custom_mmcv as mmcv
+import numpy as np
+import torch
+import torch.distributed as dist
+from custom_mmpkg.custom_mmcv.image import tensor2imgs
+from custom_mmpkg.custom_mmcv.runner import get_dist_info
+
+
+def np2tmp(array, temp_file_name=None):
+ """Save ndarray to local numpy file.
+
+ Args:
+ array (ndarray): Ndarray to save.
+ temp_file_name (str): Numpy file name. If 'temp_file_name=None', this
+ function will generate a file name with tempfile.NamedTemporaryFile
+ to save ndarray. Default: None.
+
+ Returns:
+ str: The numpy file name.
+ """
+
+ if temp_file_name is None:
+ temp_file_name = tempfile.NamedTemporaryFile(
+ suffix='.npy', delete=False).name
+ np.save(temp_file_name, array)
+ return temp_file_name
+
+
+def single_gpu_test(model,
+ data_loader,
+ show=False,
+ out_dir=None,
+ efficient_test=False,
+ opacity=0.5):
+ """Test with single GPU.
+
+ Args:
+ model (nn.Module): Model to be tested.
+ data_loader (utils.data.Dataloader): Pytorch data loader.
+ show (bool): Whether show results during inference. Default: False.
+ out_dir (str, optional): If specified, the results will be dumped into
+ the directory to save output results.
+ efficient_test (bool): Whether save the results as local numpy files to
+ save CPU memory during evaluation. Default: False.
+ opacity(float): Opacity of painted segmentation map.
+ Default 0.5.
+ Must be in (0, 1] range.
+ Returns:
+ list: The prediction results.
+ """
+
+ model.eval()
+ results = []
+ dataset = data_loader.dataset
+ prog_bar = mmcv.ProgressBar(len(dataset))
+ for i, data in enumerate(data_loader):
+ with torch.no_grad():
+ result = model(return_loss=False, **data)
+
+ if show or out_dir:
+ img_tensor = data['img'][0]
+ img_metas = data['img_metas'][0].data[0]
+ imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
+ assert len(imgs) == len(img_metas)
+
+ for img, img_meta in zip(imgs, img_metas):
+ h, w, _ = img_meta['img_shape']
+ img_show = img[:h, :w, :]
+
+ ori_h, ori_w = img_meta['ori_shape'][:-1]
+ img_show = mmcv.imresize(img_show, (ori_w, ori_h))
+
+ if out_dir:
+ out_file = osp.join(out_dir, img_meta['ori_filename'])
+ else:
+ out_file = None
+
+ model.module.show_result(
+ img_show,
+ result,
+ palette=dataset.PALETTE,
+ show=show,
+ out_file=out_file,
+ opacity=opacity)
+
+ if isinstance(result, list):
+ if efficient_test:
+ result = [np2tmp(_) for _ in result]
+ results.extend(result)
+ else:
+ if efficient_test:
+ result = np2tmp(result)
+ results.append(result)
+
+ batch_size = len(result)
+ for _ in range(batch_size):
+ prog_bar.update()
+ return results
+
+
+def multi_gpu_test(model,
+ data_loader,
+ tmpdir=None,
+ gpu_collect=False,
+ efficient_test=False):
+ """Test model with multiple gpus.
+
+ This method tests model with multiple gpus and collects the results
+ under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
+ it encodes results to gpu tensors and use gpu communication for results
+ collection. On cpu mode it saves the results on different gpus to 'tmpdir'
+ and collects them by the rank 0 worker.
+
+ Args:
+ model (nn.Module): Model to be tested.
+ data_loader (utils.data.Dataloader): Pytorch data loader.
+ tmpdir (str): Path of directory to save the temporary results from
+ different gpus under cpu mode.
+ gpu_collect (bool): Option to use either gpu or cpu to collect results.
+ efficient_test (bool): Whether save the results as local numpy files to
+ save CPU memory during evaluation. Default: False.
+
+ Returns:
+ list: The prediction results.
+ """
+
+ model.eval()
+ results = []
+ dataset = data_loader.dataset
+ rank, world_size = get_dist_info()
+ if rank == 0:
+ prog_bar = mmcv.ProgressBar(len(dataset))
+ for i, data in enumerate(data_loader):
+ with torch.no_grad():
+ result = model(return_loss=False, rescale=True, **data)
+
+ if isinstance(result, list):
+ if efficient_test:
+ result = [np2tmp(_) for _ in result]
+ results.extend(result)
+ else:
+ if efficient_test:
+ result = np2tmp(result)
+ results.append(result)
+
+ if rank == 0:
+ batch_size = data['img'][0].size(0)
+ for _ in range(batch_size * world_size):
+ prog_bar.update()
+
+ # collect results from all ranks
+ if gpu_collect:
+ results = collect_results_gpu(results, len(dataset))
+ else:
+ results = collect_results_cpu(results, len(dataset), tmpdir)
+ return results
+
+
+def collect_results_cpu(result_part, size, tmpdir=None):
+ """Collect results with CPU."""
+ rank, world_size = get_dist_info()
+ # create a tmp dir if it is not specified
+ if tmpdir is None:
+ MAX_LEN = 512
+ # 32 is whitespace
+ dir_tensor = torch.full((MAX_LEN, ),
+ 32,
+ dtype=torch.uint8,
+ device='cuda')
+ if rank == 0:
+ tmpdir = tempfile.mkdtemp()
+ tmpdir = torch.tensor(
+ bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
+ dir_tensor[:len(tmpdir)] = tmpdir
+ dist.broadcast(dir_tensor, 0)
+ tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+ else:
+ mmcv.mkdir_or_exist(tmpdir)
+ # dump the part result to the dir
+ mmcv.dump(result_part, osp.join(tmpdir, 'part_{}.pkl'.format(rank)))
+ dist.barrier()
+ # collect all parts
+ if rank != 0:
+ return None
+ else:
+ # load results of all parts from tmp dir
+ part_list = []
+ for i in range(world_size):
+ part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i))
+ part_list.append(mmcv.load(part_file))
+ # sort the results
+ ordered_results = []
+ for res in zip(*part_list):
+ ordered_results.extend(list(res))
+ # the dataloader may pad some samples
+ ordered_results = ordered_results[:size]
+ # remove tmp dir
+ shutil.rmtree(tmpdir)
+ return ordered_results
+
+
+def collect_results_gpu(result_part, size):
+ """Collect results with GPU."""
+ rank, world_size = get_dist_info()
+ # dump result part to tensor with pickle
+ part_tensor = torch.tensor(
+ bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
+ # gather all result part tensor shape
+ shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+ shape_list = [shape_tensor.clone() for _ in range(world_size)]
+ dist.all_gather(shape_list, shape_tensor)
+ # padding result part tensor to max length
+ shape_max = torch.tensor(shape_list).max()
+ part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+ part_send[:shape_tensor[0]] = part_tensor
+ part_recv_list = [
+ part_tensor.new_zeros(shape_max) for _ in range(world_size)
+ ]
+ # gather all result part
+ dist.all_gather(part_recv_list, part_send)
+
+ if rank == 0:
+ part_list = []
+ for recv, shape in zip(part_recv_list, shape_list):
+ part_list.append(
+ pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
+ # sort the results
+ ordered_results = []
+ for res in zip(*part_list):
+ ordered_results.extend(list(res))
+ # the dataloader may pad some samples
+ ordered_results = ordered_results[:size]
+ return ordered_results
diff --git a/src/custom_mmpkg/custom_mmseg/apis/train.py b/src/custom_mmpkg/custom_mmseg/apis/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..61eb4768b375cf8e3cd5323d5533221e8238c4c8
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/apis/train.py
@@ -0,0 +1,116 @@
+import random
+import warnings
+
+import numpy as np
+import torch
+from custom_mmpkg.custom_mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+from custom_mmpkg.custom_mmcv.runner import build_optimizer, build_runner
+
+from custom_mmpkg.custom_mmseg.core import DistEvalHook, EvalHook
+from custom_mmpkg.custom_mmseg.datasets import build_dataloader, build_dataset
+from custom_mmpkg.custom_mmseg.utils import get_root_logger
+
+
+def set_random_seed(seed, deterministic=False):
+ """Set random seed.
+
+ Args:
+ seed (int): Seed to be used.
+ deterministic (bool): Whether to set the deterministic option for
+ CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+ to True and `torch.backends.cudnn.benchmark` to False.
+ Default: False.
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ if deterministic:
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def train_segmentor(model,
+ dataset,
+ cfg,
+ distributed=False,
+ validate=False,
+ timestamp=None,
+ meta=None):
+ """Launch segmentor training."""
+ logger = get_root_logger(cfg.log_level)
+
+ # prepare data loaders
+ dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
+ data_loaders = [
+ build_dataloader(
+ ds,
+ cfg.data.samples_per_gpu,
+ cfg.data.workers_per_gpu,
+ # cfg.gpus will be ignored if distributed
+ len(cfg.gpu_ids),
+ dist=distributed,
+ seed=cfg.seed,
+ drop_last=True) for ds in dataset
+ ]
+
+ # put model on gpus
+ if distributed:
+ find_unused_parameters = cfg.get('find_unused_parameters', False)
+ # Sets the `find_unused_parameters` parameter in
+ # torch.nn.parallel.DistributedDataParallel
+ model = MMDistributedDataParallel(
+ model.cuda(),
+ device_ids=[torch.cuda.current_device()],
+ broadcast_buffers=False,
+ find_unused_parameters=find_unused_parameters)
+ else:
+ model = MMDataParallel(
+ model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
+
+ # build runner
+ optimizer = build_optimizer(model, cfg.optimizer)
+
+ if cfg.get('runner') is None:
+ cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters}
+ warnings.warn(
+ 'config is now expected to have a `runner` section, '
+ 'please set `runner` in your config.', UserWarning)
+
+ runner = build_runner(
+ cfg.runner,
+ default_args=dict(
+ model=model,
+ batch_processor=None,
+ optimizer=optimizer,
+ work_dir=cfg.work_dir,
+ logger=logger,
+ meta=meta))
+
+ # register hooks
+ runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
+ cfg.checkpoint_config, cfg.log_config,
+ cfg.get('momentum_config', None))
+
+ # an ugly walkaround to make the .log and .log.json filenames the same
+ runner.timestamp = timestamp
+
+ # register eval hooks
+ if validate:
+ val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
+ val_dataloader = build_dataloader(
+ val_dataset,
+ samples_per_gpu=1,
+ workers_per_gpu=cfg.data.workers_per_gpu,
+ dist=distributed,
+ shuffle=False)
+ eval_cfg = cfg.get('evaluation', {})
+ eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
+ eval_hook = DistEvalHook if distributed else EvalHook
+ runner.register_hook(eval_hook(val_dataloader, **eval_cfg), priority='LOW')
+
+ if cfg.resume_from:
+ runner.resume(cfg.resume_from)
+ elif cfg.load_from:
+ runner.load_checkpoint(cfg.load_from)
+ runner.run(data_loaders, cfg.workflow)
diff --git a/src/custom_mmpkg/custom_mmseg/core/__init__.py b/src/custom_mmpkg/custom_mmseg/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..965605587211b7bf0bd6bc3acdbb33dd49cab023
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/core/__init__.py
@@ -0,0 +1,3 @@
+from .evaluation import * # noqa: F401, F403
+from .seg import * # noqa: F401, F403
+from .utils import * # noqa: F401, F403
diff --git a/src/custom_mmpkg/custom_mmseg/core/evaluation/__init__.py b/src/custom_mmpkg/custom_mmseg/core/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7cc4b23413a0639e9de00eeb0bf600632d2c6cd
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/core/evaluation/__init__.py
@@ -0,0 +1,8 @@
+from .class_names import get_classes, get_palette
+from .eval_hooks import DistEvalHook, EvalHook
+from .metrics import eval_metrics, mean_dice, mean_fscore, mean_iou
+
+__all__ = [
+ 'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore',
+ 'eval_metrics', 'get_classes', 'get_palette'
+]
diff --git a/src/custom_mmpkg/custom_mmseg/core/evaluation/class_names.py b/src/custom_mmpkg/custom_mmseg/core/evaluation/class_names.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e79082966879d06da504a8105646257f103a07c
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/core/evaluation/class_names.py
@@ -0,0 +1,152 @@
+import custom_mmpkg.custom_mmcv as mmcv
+
+
+def cityscapes_classes():
+ """Cityscapes class names for external use."""
+ return [
+ 'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
+ 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
+ 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
+ 'bicycle'
+ ]
+
+
+def ade_classes():
+ """ADE20K class names for external use."""
+ return [
+ 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
+ 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
+ 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
+ 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
+ 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
+ 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
+ 'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
+ 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
+ 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
+ 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
+ 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
+ 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
+ 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
+ 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
+ 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
+ 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
+ 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
+ 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
+ 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
+ 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
+ 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
+ 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
+ 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
+ 'clock', 'flag'
+ ]
+
+
+def voc_classes():
+ """Pascal VOC class names for external use."""
+ return [
+ 'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
+ 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
+ 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
+ 'tvmonitor'
+ ]
+
+
+def cityscapes_palette():
+ """Cityscapes palette for external use."""
+ return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
+ [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
+ [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
+ [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
+ [0, 0, 230], [119, 11, 32]]
+
+
+def ade_palette():
+ """ADE20K palette for external use."""
+ return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
+ [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
+ [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
+ [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
+ [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
+ [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
+ [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
+ [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
+ [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
+ [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
+ [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
+ [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
+ [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
+ [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
+ [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
+ [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
+ [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
+ [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
+ [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
+ [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
+ [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
+ [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
+ [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
+ [102, 255, 0], [92, 0, 255]]
+
+
+def voc_palette():
+ """Pascal VOC palette for external use."""
+ return [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
+ [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
+ [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
+ [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
+ [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
+
+
+dataset_aliases = {
+ 'cityscapes': ['cityscapes'],
+ 'ade': ['ade', 'ade20k'],
+ 'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug']
+}
+
+
+def get_classes(dataset):
+ """Get class names of a dataset."""
+ alias2name = {}
+ for name, aliases in dataset_aliases.items():
+ for alias in aliases:
+ alias2name[alias] = name
+
+ if mmcv.is_str(dataset):
+ if dataset in alias2name:
+ labels = eval(alias2name[dataset] + '_classes()')
+ else:
+ raise ValueError(f'Unrecognized dataset: {dataset}')
+ else:
+ raise TypeError(f'dataset must a str, but got {type(dataset)}')
+ return labels
+
+
+def get_palette(dataset):
+ """Get class palette (RGB) of a dataset."""
+ alias2name = {}
+ for name, aliases in dataset_aliases.items():
+ for alias in aliases:
+ alias2name[alias] = name
+
+ if mmcv.is_str(dataset):
+ if dataset in alias2name:
+ labels = eval(alias2name[dataset] + '_palette()')
+ else:
+ raise ValueError(f'Unrecognized dataset: {dataset}')
+ else:
+ raise TypeError(f'dataset must a str, but got {type(dataset)}')
+ return labels
diff --git a/src/custom_mmpkg/custom_mmseg/core/evaluation/eval_hooks.py b/src/custom_mmpkg/custom_mmseg/core/evaluation/eval_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..684fd6c291bae6255cd835ba3d32c1cacca536c8
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/core/evaluation/eval_hooks.py
@@ -0,0 +1,109 @@
+import os.path as osp
+
+from custom_mmpkg.custom_mmcv.runner import DistEvalHook as _DistEvalHook
+from custom_mmpkg.custom_mmcv.runner import EvalHook as _EvalHook
+
+
+class EvalHook(_EvalHook):
+ """Single GPU EvalHook, with efficient test support.
+
+ Args:
+ by_epoch (bool): Determine perform evaluation by epoch or by iteration.
+ If set to True, it will perform by epoch. Otherwise, by iteration.
+ Default: False.
+ efficient_test (bool): Whether save the results as local numpy files to
+ save CPU memory during evaluation. Default: False.
+ Returns:
+ list: The prediction results.
+ """
+
+ greater_keys = ['mIoU', 'mAcc', 'aAcc']
+
+ def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
+ super().__init__(*args, by_epoch=by_epoch, **kwargs)
+ self.efficient_test = efficient_test
+
+ def after_train_iter(self, runner):
+ """After train epoch hook.
+
+ Override default ``single_gpu_test``.
+ """
+ if self.by_epoch or not self.every_n_iters(runner, self.interval):
+ return
+ from custom_mmpkg.custom_mmseg.apis import single_gpu_test
+ runner.log_buffer.clear()
+ results = single_gpu_test(
+ runner.model,
+ self.dataloader,
+ show=False,
+ efficient_test=self.efficient_test)
+ self.evaluate(runner, results)
+
+ def after_train_epoch(self, runner):
+ """After train epoch hook.
+
+ Override default ``single_gpu_test``.
+ """
+ if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
+ return
+ from custom_mmpkg.custom_mmseg.apis import single_gpu_test
+ runner.log_buffer.clear()
+ results = single_gpu_test(runner.model, self.dataloader, show=False)
+ self.evaluate(runner, results)
+
+
+class DistEvalHook(_DistEvalHook):
+ """Distributed EvalHook, with efficient test support.
+
+ Args:
+ by_epoch (bool): Determine perform evaluation by epoch or by iteration.
+ If set to True, it will perform by epoch. Otherwise, by iteration.
+ Default: False.
+ efficient_test (bool): Whether save the results as local numpy files to
+ save CPU memory during evaluation. Default: False.
+ Returns:
+ list: The prediction results.
+ """
+
+ greater_keys = ['mIoU', 'mAcc', 'aAcc']
+
+ def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
+ super().__init__(*args, by_epoch=by_epoch, **kwargs)
+ self.efficient_test = efficient_test
+
+ def after_train_iter(self, runner):
+ """After train epoch hook.
+
+ Override default ``multi_gpu_test``.
+ """
+ if self.by_epoch or not self.every_n_iters(runner, self.interval):
+ return
+ from custom_mmpkg.custom_mmseg.apis import multi_gpu_test
+ runner.log_buffer.clear()
+ results = multi_gpu_test(
+ runner.model,
+ self.dataloader,
+ tmpdir=osp.join(runner.work_dir, '.eval_hook'),
+ gpu_collect=self.gpu_collect,
+ efficient_test=self.efficient_test)
+ if runner.rank == 0:
+ print('\n')
+ self.evaluate(runner, results)
+
+ def after_train_epoch(self, runner):
+ """After train epoch hook.
+
+ Override default ``multi_gpu_test``.
+ """
+ if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
+ return
+ from custom_mmpkg.custom_mmseg.apis import multi_gpu_test
+ runner.log_buffer.clear()
+ results = multi_gpu_test(
+ runner.model,
+ self.dataloader,
+ tmpdir=osp.join(runner.work_dir, '.eval_hook'),
+ gpu_collect=self.gpu_collect)
+ if runner.rank == 0:
+ print('\n')
+ self.evaluate(runner, results)
diff --git a/src/custom_mmpkg/custom_mmseg/core/evaluation/metrics.py b/src/custom_mmpkg/custom_mmseg/core/evaluation/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..db4b29f6c277ce43e4a0f39c3898a2938e11dba8
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/core/evaluation/metrics.py
@@ -0,0 +1,326 @@
+from collections import OrderedDict
+
+import custom_mmpkg.custom_mmcv as mmcv
+import numpy as np
+import torch
+
+
+def f_score(precision, recall, beta=1):
+ """calcuate the f-score value.
+
+ Args:
+ precision (float | torch.Tensor): The precision value.
+ recall (float | torch.Tensor): The recall value.
+ beta (int): Determines the weight of recall in the combined score.
+ Default: False.
+
+ Returns:
+ [torch.tensor]: The f-score value.
+ """
+ score = (1 + beta**2) * (precision * recall) / (
+ (beta**2 * precision) + recall)
+ return score
+
+
+def intersect_and_union(pred_label,
+ label,
+ num_classes,
+ ignore_index,
+ label_map=dict(),
+ reduce_zero_label=False):
+ """Calculate intersection and Union.
+
+ Args:
+ pred_label (ndarray | str): Prediction segmentation map
+ or predict result filename.
+ label (ndarray | str): Ground truth segmentation map
+ or label filename.
+ num_classes (int): Number of categories.
+ ignore_index (int): Index that will be ignored in evaluation.
+ label_map (dict): Mapping old labels to new labels. The parameter will
+ work only when label is str. Default: dict().
+ reduce_zero_label (bool): Whether ignore zero label. The parameter will
+ work only when label is str. Default: False.
+
+ Returns:
+ torch.Tensor: The intersection of prediction and ground truth
+ histogram on all classes.
+ torch.Tensor: The union of prediction and ground truth histogram on
+ all classes.
+ torch.Tensor: The prediction histogram on all classes.
+ torch.Tensor: The ground truth histogram on all classes.
+ """
+
+ if isinstance(pred_label, str):
+ pred_label = torch.from_numpy(np.load(pred_label))
+ else:
+ pred_label = torch.from_numpy((pred_label))
+
+ if isinstance(label, str):
+ label = torch.from_numpy(
+ mmcv.imread(label, flag='unchanged', backend='pillow'))
+ else:
+ label = torch.from_numpy(label)
+
+ if label_map is not None:
+ for old_id, new_id in label_map.items():
+ label[label == old_id] = new_id
+ if reduce_zero_label:
+ label[label == 0] = 255
+ label = label - 1
+ label[label == 254] = 255
+
+ mask = (label != ignore_index)
+ pred_label = pred_label[mask]
+ label = label[mask]
+
+ intersect = pred_label[pred_label == label]
+ area_intersect = torch.histc(
+ intersect.float(), bins=(num_classes), min=0, max=num_classes - 1)
+ area_pred_label = torch.histc(
+ pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1)
+ area_label = torch.histc(
+ label.float(), bins=(num_classes), min=0, max=num_classes - 1)
+ area_union = area_pred_label + area_label - area_intersect
+ return area_intersect, area_union, area_pred_label, area_label
+
+
+def total_intersect_and_union(results,
+ gt_seg_maps,
+ num_classes,
+ ignore_index,
+ label_map=dict(),
+ reduce_zero_label=False):
+ """Calculate Total Intersection and Union.
+
+ Args:
+ results (list[ndarray] | list[str]): List of prediction segmentation
+ maps or list of prediction result filenames.
+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+ segmentation maps or list of label filenames.
+ num_classes (int): Number of categories.
+ ignore_index (int): Index that will be ignored in evaluation.
+ label_map (dict): Mapping old labels to new labels. Default: dict().
+ reduce_zero_label (bool): Whether ignore zero label. Default: False.
+
+ Returns:
+ ndarray: The intersection of prediction and ground truth histogram
+ on all classes.
+ ndarray: The union of prediction and ground truth histogram on all
+ classes.
+ ndarray: The prediction histogram on all classes.
+ ndarray: The ground truth histogram on all classes.
+ """
+ num_imgs = len(results)
+ assert len(gt_seg_maps) == num_imgs
+ total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
+ total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
+ total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
+ total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)
+ for i in range(num_imgs):
+ area_intersect, area_union, area_pred_label, area_label = \
+ intersect_and_union(
+ results[i], gt_seg_maps[i], num_classes, ignore_index,
+ label_map, reduce_zero_label)
+ total_area_intersect += area_intersect
+ total_area_union += area_union
+ total_area_pred_label += area_pred_label
+ total_area_label += area_label
+ return total_area_intersect, total_area_union, total_area_pred_label, \
+ total_area_label
+
+
+def mean_iou(results,
+ gt_seg_maps,
+ num_classes,
+ ignore_index,
+ nan_to_num=None,
+ label_map=dict(),
+ reduce_zero_label=False):
+ """Calculate Mean Intersection and Union (mIoU)
+
+ Args:
+ results (list[ndarray] | list[str]): List of prediction segmentation
+ maps or list of prediction result filenames.
+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+ segmentation maps or list of label filenames.
+ num_classes (int): Number of categories.
+ ignore_index (int): Index that will be ignored in evaluation.
+ nan_to_num (int, optional): If specified, NaN values will be replaced
+ by the numbers defined by the user. Default: None.
+ label_map (dict): Mapping old labels to new labels. Default: dict().
+ reduce_zero_label (bool): Whether ignore zero label. Default: False.
+
+ Returns:
+ dict[str, float | ndarray]:
+ float: Overall accuracy on all images.
+ ndarray: Per category accuracy, shape (num_classes, ).
+ ndarray: Per category IoU, shape (num_classes, ).
+ """
+ iou_result = eval_metrics(
+ results=results,
+ gt_seg_maps=gt_seg_maps,
+ num_classes=num_classes,
+ ignore_index=ignore_index,
+ metrics=['mIoU'],
+ nan_to_num=nan_to_num,
+ label_map=label_map,
+ reduce_zero_label=reduce_zero_label)
+ return iou_result
+
+
+def mean_dice(results,
+ gt_seg_maps,
+ num_classes,
+ ignore_index,
+ nan_to_num=None,
+ label_map=dict(),
+ reduce_zero_label=False):
+ """Calculate Mean Dice (mDice)
+
+ Args:
+ results (list[ndarray] | list[str]): List of prediction segmentation
+ maps or list of prediction result filenames.
+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+ segmentation maps or list of label filenames.
+ num_classes (int): Number of categories.
+ ignore_index (int): Index that will be ignored in evaluation.
+ nan_to_num (int, optional): If specified, NaN values will be replaced
+ by the numbers defined by the user. Default: None.
+ label_map (dict): Mapping old labels to new labels. Default: dict().
+ reduce_zero_label (bool): Whether ignore zero label. Default: False.
+
+ Returns:
+ dict[str, float | ndarray]: Default metrics.
+ float: Overall accuracy on all images.
+ ndarray: Per category accuracy, shape (num_classes, ).
+ ndarray: Per category dice, shape (num_classes, ).
+ """
+
+ dice_result = eval_metrics(
+ results=results,
+ gt_seg_maps=gt_seg_maps,
+ num_classes=num_classes,
+ ignore_index=ignore_index,
+ metrics=['mDice'],
+ nan_to_num=nan_to_num,
+ label_map=label_map,
+ reduce_zero_label=reduce_zero_label)
+ return dice_result
+
+
+def mean_fscore(results,
+ gt_seg_maps,
+ num_classes,
+ ignore_index,
+ nan_to_num=None,
+ label_map=dict(),
+ reduce_zero_label=False,
+ beta=1):
+ """Calculate Mean Intersection and Union (mIoU)
+
+ Args:
+ results (list[ndarray] | list[str]): List of prediction segmentation
+ maps or list of prediction result filenames.
+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+ segmentation maps or list of label filenames.
+ num_classes (int): Number of categories.
+ ignore_index (int): Index that will be ignored in evaluation.
+ nan_to_num (int, optional): If specified, NaN values will be replaced
+ by the numbers defined by the user. Default: None.
+ label_map (dict): Mapping old labels to new labels. Default: dict().
+ reduce_zero_label (bool): Whether ignore zero label. Default: False.
+ beta (int): Determines the weight of recall in the combined score.
+ Default: False.
+
+
+ Returns:
+ dict[str, float | ndarray]: Default metrics.
+ float: Overall accuracy on all images.
+ ndarray: Per category recall, shape (num_classes, ).
+ ndarray: Per category precision, shape (num_classes, ).
+ ndarray: Per category f-score, shape (num_classes, ).
+ """
+ fscore_result = eval_metrics(
+ results=results,
+ gt_seg_maps=gt_seg_maps,
+ num_classes=num_classes,
+ ignore_index=ignore_index,
+ metrics=['mFscore'],
+ nan_to_num=nan_to_num,
+ label_map=label_map,
+ reduce_zero_label=reduce_zero_label,
+ beta=beta)
+ return fscore_result
+
+
+def eval_metrics(results,
+ gt_seg_maps,
+ num_classes,
+ ignore_index,
+ metrics=['mIoU'],
+ nan_to_num=None,
+ label_map=dict(),
+ reduce_zero_label=False,
+ beta=1):
+ """Calculate evaluation metrics
+ Args:
+ results (list[ndarray] | list[str]): List of prediction segmentation
+ maps or list of prediction result filenames.
+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+ segmentation maps or list of label filenames.
+ num_classes (int): Number of categories.
+ ignore_index (int): Index that will be ignored in evaluation.
+ metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
+ nan_to_num (int, optional): If specified, NaN values will be replaced
+ by the numbers defined by the user. Default: None.
+ label_map (dict): Mapping old labels to new labels. Default: dict().
+ reduce_zero_label (bool): Whether ignore zero label. Default: False.
+ Returns:
+ float: Overall accuracy on all images.
+ ndarray: Per category accuracy, shape (num_classes, ).
+ ndarray: Per category evaluation metrics, shape (num_classes, ).
+ """
+ if isinstance(metrics, str):
+ metrics = [metrics]
+ allowed_metrics = ['mIoU', 'mDice', 'mFscore']
+ if not set(metrics).issubset(set(allowed_metrics)):
+ raise KeyError('metrics {} is not supported'.format(metrics))
+
+ total_area_intersect, total_area_union, total_area_pred_label, \
+ total_area_label = total_intersect_and_union(
+ results, gt_seg_maps, num_classes, ignore_index, label_map,
+ reduce_zero_label)
+ all_acc = total_area_intersect.sum() / total_area_label.sum()
+ ret_metrics = OrderedDict({'aAcc': all_acc})
+ for metric in metrics:
+ if metric == 'mIoU':
+ iou = total_area_intersect / total_area_union
+ acc = total_area_intersect / total_area_label
+ ret_metrics['IoU'] = iou
+ ret_metrics['Acc'] = acc
+ elif metric == 'mDice':
+ dice = 2 * total_area_intersect / (
+ total_area_pred_label + total_area_label)
+ acc = total_area_intersect / total_area_label
+ ret_metrics['Dice'] = dice
+ ret_metrics['Acc'] = acc
+ elif metric == 'mFscore':
+ precision = total_area_intersect / total_area_pred_label
+ recall = total_area_intersect / total_area_label
+ f_value = torch.tensor(
+ [f_score(x[0], x[1], beta) for x in zip(precision, recall)])
+ ret_metrics['Fscore'] = f_value
+ ret_metrics['Precision'] = precision
+ ret_metrics['Recall'] = recall
+
+ ret_metrics = {
+ metric: value.numpy()
+ for metric, value in ret_metrics.items()
+ }
+ if nan_to_num is not None:
+ ret_metrics = OrderedDict({
+ metric: np.nan_to_num(metric_value, nan=nan_to_num)
+ for metric, metric_value in ret_metrics.items()
+ })
+ return ret_metrics
diff --git a/src/custom_mmpkg/custom_mmseg/core/seg/__init__.py b/src/custom_mmpkg/custom_mmseg/core/seg/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..93bc129b685e4a3efca2cc891729981b2865900d
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/core/seg/__init__.py
@@ -0,0 +1,4 @@
+from .builder import build_pixel_sampler
+from .sampler import BasePixelSampler, OHEMPixelSampler
+
+__all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler']
diff --git a/src/custom_mmpkg/custom_mmseg/core/seg/builder.py b/src/custom_mmpkg/custom_mmseg/core/seg/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c6971fce1e60b12c521413bf62127da76f441d4
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/core/seg/builder.py
@@ -0,0 +1,8 @@
+from custom_mmpkg.custom_mmcv.utils import Registry, build_from_cfg
+
+PIXEL_SAMPLERS = Registry('pixel sampler')
+
+
+def build_pixel_sampler(cfg, **default_args):
+ """Build pixel sampler for segmentation map."""
+ return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args)
diff --git a/src/custom_mmpkg/custom_mmseg/core/seg/sampler/__init__.py b/src/custom_mmpkg/custom_mmseg/core/seg/sampler/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..332b242c03d1c5e80d4577df442a9a037b1816e1
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/core/seg/sampler/__init__.py
@@ -0,0 +1,4 @@
+from .base_pixel_sampler import BasePixelSampler
+from .ohem_pixel_sampler import OHEMPixelSampler
+
+__all__ = ['BasePixelSampler', 'OHEMPixelSampler']
diff --git a/src/custom_mmpkg/custom_mmseg/core/seg/sampler/base_pixel_sampler.py b/src/custom_mmpkg/custom_mmseg/core/seg/sampler/base_pixel_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b75b1566c9f18169cee51d4b55d75e0357b69c57
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/core/seg/sampler/base_pixel_sampler.py
@@ -0,0 +1,12 @@
+from abc import ABCMeta, abstractmethod
+
+
+class BasePixelSampler(metaclass=ABCMeta):
+ """Base class of pixel sampler."""
+
+ def __init__(self, **kwargs):
+ pass
+
+ @abstractmethod
+ def sample(self, seg_logit, seg_label):
+ """Placeholder for sample function."""
diff --git a/src/custom_mmpkg/custom_mmseg/core/seg/sampler/ohem_pixel_sampler.py b/src/custom_mmpkg/custom_mmseg/core/seg/sampler/ohem_pixel_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..88bb10d44026ba9f21756eaea9e550841cd59b9f
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/core/seg/sampler/ohem_pixel_sampler.py
@@ -0,0 +1,76 @@
+import torch
+import torch.nn.functional as F
+
+from ..builder import PIXEL_SAMPLERS
+from .base_pixel_sampler import BasePixelSampler
+
+
+@PIXEL_SAMPLERS.register_module()
+class OHEMPixelSampler(BasePixelSampler):
+ """Online Hard Example Mining Sampler for segmentation.
+
+ Args:
+ context (nn.Module): The context of sampler, subclass of
+ :obj:`BaseDecodeHead`.
+ thresh (float, optional): The threshold for hard example selection.
+ Below which, are prediction with low confidence. If not
+ specified, the hard examples will be pixels of top ``min_kept``
+ loss. Default: None.
+ min_kept (int, optional): The minimum number of predictions to keep.
+ Default: 100000.
+ """
+
+ def __init__(self, context, thresh=None, min_kept=100000):
+ super(OHEMPixelSampler, self).__init__()
+ self.context = context
+ assert min_kept > 1
+ self.thresh = thresh
+ self.min_kept = min_kept
+
+ def sample(self, seg_logit, seg_label):
+ """Sample pixels that have high loss or with low prediction confidence.
+
+ Args:
+ seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
+ seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W)
+
+ Returns:
+ torch.Tensor: segmentation weight, shape (N, H, W)
+ """
+ with torch.no_grad():
+ assert seg_logit.shape[2:] == seg_label.shape[2:]
+ assert seg_label.shape[1] == 1
+ seg_label = seg_label.squeeze(1).long()
+ batch_kept = self.min_kept * seg_label.size(0)
+ valid_mask = seg_label != self.context.ignore_index
+ seg_weight = seg_logit.new_zeros(size=seg_label.size())
+ valid_seg_weight = seg_weight[valid_mask]
+ if self.thresh is not None:
+ seg_prob = F.softmax(seg_logit, dim=1)
+
+ tmp_seg_label = seg_label.clone().unsqueeze(1)
+ tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0
+ seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)
+ sort_prob, sort_indices = seg_prob[valid_mask].sort()
+
+ if sort_prob.numel() > 0:
+ min_threshold = sort_prob[min(batch_kept,
+ sort_prob.numel() - 1)]
+ else:
+ min_threshold = 0.0
+ threshold = max(min_threshold, self.thresh)
+ valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
+ else:
+ losses = self.context.loss_decode(
+ seg_logit,
+ seg_label,
+ weight=None,
+ ignore_index=self.context.ignore_index,
+ reduction_override='none')
+ # faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa
+ _, sort_indices = losses[valid_mask].sort(descending=True)
+ valid_seg_weight[sort_indices[:batch_kept]] = 1.
+
+ seg_weight[valid_mask] = valid_seg_weight
+
+ return seg_weight
diff --git a/src/custom_mmpkg/custom_mmseg/core/utils/__init__.py b/src/custom_mmpkg/custom_mmseg/core/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2678b321c295bcceaef945111ac3524be19d6e4
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/core/utils/__init__.py
@@ -0,0 +1,3 @@
+from .misc import add_prefix
+
+__all__ = ['add_prefix']
diff --git a/src/custom_mmpkg/custom_mmseg/core/utils/misc.py b/src/custom_mmpkg/custom_mmseg/core/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb862a82bd47c8624db3dd5c6fb6ad8a03b62466
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/core/utils/misc.py
@@ -0,0 +1,17 @@
+def add_prefix(inputs, prefix):
+ """Add prefix for dict.
+
+ Args:
+ inputs (dict): The input dict with str keys.
+ prefix (str): The prefix to add.
+
+ Returns:
+
+ dict: The dict with keys updated with ``prefix``.
+ """
+
+ outputs = dict()
+ for name, value in inputs.items():
+ outputs[f'{prefix}.{name}'] = value
+
+ return outputs
diff --git a/src/custom_mmpkg/custom_mmseg/datasets/__init__.py b/src/custom_mmpkg/custom_mmseg/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebeaef4a28ef655e43578552a8aef6b77f13a636
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/datasets/__init__.py
@@ -0,0 +1,19 @@
+from .ade import ADE20KDataset
+from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
+from .chase_db1 import ChaseDB1Dataset
+from .cityscapes import CityscapesDataset
+from .custom import CustomDataset
+from .dataset_wrappers import ConcatDataset, RepeatDataset
+from .drive import DRIVEDataset
+from .hrf import HRFDataset
+from .pascal_context import PascalContextDataset, PascalContextDataset59
+from .stare import STAREDataset
+from .voc import PascalVOCDataset
+
+__all__ = [
+ 'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
+ 'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset',
+ 'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
+ 'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
+ 'STAREDataset'
+]
diff --git a/src/custom_mmpkg/custom_mmseg/datasets/ade.py b/src/custom_mmpkg/custom_mmseg/datasets/ade.py
new file mode 100644
index 0000000000000000000000000000000000000000..5913e43775ed4920b6934c855eb5a37c54218ebf
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/datasets/ade.py
@@ -0,0 +1,84 @@
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class ADE20KDataset(CustomDataset):
+ """ADE20K dataset.
+
+ In segmentation map annotation for ADE20K, 0 stands for background, which
+ is not included in 150 categories. ``reduce_zero_label`` is fixed to True.
+ The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to
+ '.png'.
+ """
+ CLASSES = (
+ 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
+ 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
+ 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
+ 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
+ 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
+ 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
+ 'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
+ 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
+ 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
+ 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
+ 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
+ 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
+ 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
+ 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
+ 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
+ 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
+ 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
+ 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
+ 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
+ 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
+ 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
+ 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
+ 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
+ 'clock', 'flag')
+
+ PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
+ [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
+ [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
+ [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
+ [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
+ [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
+ [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
+ [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
+ [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
+ [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
+ [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
+ [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
+ [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
+ [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
+ [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
+ [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
+ [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
+ [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
+ [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
+ [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
+ [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
+ [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
+ [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
+ [102, 255, 0], [92, 0, 255]]
+
+ def __init__(self, **kwargs):
+ super(ADE20KDataset, self).__init__(
+ img_suffix='.jpg',
+ seg_map_suffix='.png',
+ reduce_zero_label=True,
+ **kwargs)
diff --git a/src/custom_mmpkg/custom_mmseg/datasets/builder.py b/src/custom_mmpkg/custom_mmseg/datasets/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d09e961eafb5301c98fd3defeb558f9b7e938e7
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/datasets/builder.py
@@ -0,0 +1,161 @@
+import copy
+import platform
+import random
+from functools import partial
+
+import numpy as np
+from custom_mmpkg.custom_mmcv.parallel import collate
+from custom_mmpkg.custom_mmcv.runner import get_dist_info
+from custom_mmpkg.custom_mmcv.utils import Registry, build_from_cfg
+from custom_mmpkg.custom_mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader
+from torch.utils.data import DistributedSampler
+
+DATASETS = Registry('dataset')
+PIPELINES = Registry('pipeline')
+
+
+def _concat_dataset(cfg, default_args=None):
+ """Build :obj:`ConcatDataset by."""
+ from .dataset_wrappers import ConcatDataset
+ img_dir = cfg['img_dir']
+ ann_dir = cfg.get('ann_dir', None)
+ split = cfg.get('split', None)
+ num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1
+ if ann_dir is not None:
+ num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1
+ else:
+ num_ann_dir = 0
+ if split is not None:
+ num_split = len(split) if isinstance(split, (list, tuple)) else 1
+ else:
+ num_split = 0
+ if num_img_dir > 1:
+ assert num_img_dir == num_ann_dir or num_ann_dir == 0
+ assert num_img_dir == num_split or num_split == 0
+ else:
+ assert num_split == num_ann_dir or num_ann_dir <= 1
+ num_dset = max(num_split, num_img_dir)
+
+ datasets = []
+ for i in range(num_dset):
+ data_cfg = copy.deepcopy(cfg)
+ if isinstance(img_dir, (list, tuple)):
+ data_cfg['img_dir'] = img_dir[i]
+ if isinstance(ann_dir, (list, tuple)):
+ data_cfg['ann_dir'] = ann_dir[i]
+ if isinstance(split, (list, tuple)):
+ data_cfg['split'] = split[i]
+ datasets.append(build_dataset(data_cfg, default_args))
+
+ return ConcatDataset(datasets)
+
+
+def build_dataset(cfg, default_args=None):
+ """Build datasets."""
+ from .dataset_wrappers import ConcatDataset, RepeatDataset
+ if isinstance(cfg, (list, tuple)):
+ dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
+ elif cfg['type'] == 'RepeatDataset':
+ dataset = RepeatDataset(
+ build_dataset(cfg['dataset'], default_args), cfg['times'])
+ elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance(
+ cfg.get('split', None), (list, tuple)):
+ dataset = _concat_dataset(cfg, default_args)
+ else:
+ dataset = build_from_cfg(cfg, DATASETS, default_args)
+
+ return dataset
+
+
+def build_dataloader(dataset,
+ samples_per_gpu,
+ workers_per_gpu,
+ num_gpus=1,
+ dist=True,
+ shuffle=True,
+ seed=None,
+ drop_last=False,
+ pin_memory=True,
+ dataloader_type='PoolDataLoader',
+ **kwargs):
+ """Build PyTorch DataLoader.
+
+ In distributed training, each GPU/process has a dataloader.
+ In non-distributed training, there is only one dataloader for all GPUs.
+
+ Args:
+ dataset (Dataset): A PyTorch dataset.
+ samples_per_gpu (int): Number of training samples on each GPU, i.e.,
+ batch size of each GPU.
+ workers_per_gpu (int): How many subprocesses to use for data loading
+ for each GPU.
+ num_gpus (int): Number of GPUs. Only used in non-distributed training.
+ dist (bool): Distributed training/test or not. Default: True.
+ shuffle (bool): Whether to shuffle the data at every epoch.
+ Default: True.
+ seed (int | None): Seed to be used. Default: None.
+ drop_last (bool): Whether to drop the last incomplete batch in epoch.
+ Default: False
+ pin_memory (bool): Whether to use pin_memory in DataLoader.
+ Default: True
+ dataloader_type (str): Type of dataloader. Default: 'PoolDataLoader'
+ kwargs: any keyword argument to be used to initialize DataLoader
+
+ Returns:
+ DataLoader: A PyTorch dataloader.
+ """
+ rank, world_size = get_dist_info()
+ if dist:
+ sampler = DistributedSampler(
+ dataset, world_size, rank, shuffle=shuffle)
+ shuffle = False
+ batch_size = samples_per_gpu
+ num_workers = workers_per_gpu
+ else:
+ sampler = None
+ batch_size = num_gpus * samples_per_gpu
+ num_workers = num_gpus * workers_per_gpu
+
+ init_fn = partial(
+ worker_init_fn, num_workers=num_workers, rank=rank,
+ seed=seed) if seed is not None else None
+
+ assert dataloader_type in (
+ 'DataLoader',
+ 'PoolDataLoader'), f'unsupported dataloader {dataloader_type}'
+
+ if dataloader_type == 'PoolDataLoader':
+ dataloader = PoolDataLoader
+ elif dataloader_type == 'DataLoader':
+ dataloader = DataLoader
+
+ data_loader = dataloader(
+ dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ num_workers=num_workers,
+ collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
+ pin_memory=pin_memory,
+ shuffle=shuffle,
+ worker_init_fn=init_fn,
+ drop_last=drop_last,
+ **kwargs)
+
+ return data_loader
+
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+ """Worker init func for dataloader.
+
+ The seed of each worker equals to num_worker * rank + worker_id + user_seed
+
+ Args:
+ worker_id (int): Worker id.
+ num_workers (int): Number of workers.
+ rank (int): The rank of current process.
+ seed (int): The random seed to use.
+ """
+
+ worker_seed = num_workers * rank + worker_id + seed
+ np.random.seed(worker_seed)
+ random.seed(worker_seed)
diff --git a/src/custom_mmpkg/custom_mmseg/datasets/chase_db1.py b/src/custom_mmpkg/custom_mmseg/datasets/chase_db1.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bc29bea14704a4407f83474610cbc3bef32c708
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/datasets/chase_db1.py
@@ -0,0 +1,27 @@
+import os.path as osp
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class ChaseDB1Dataset(CustomDataset):
+ """Chase_db1 dataset.
+
+ In segmentation map annotation for Chase_db1, 0 stands for background,
+ which is included in 2 categories. ``reduce_zero_label`` is fixed to False.
+ The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
+ '_1stHO.png'.
+ """
+
+ CLASSES = ('background', 'vessel')
+
+ PALETTE = [[120, 120, 120], [6, 230, 230]]
+
+ def __init__(self, **kwargs):
+ super(ChaseDB1Dataset, self).__init__(
+ img_suffix='.png',
+ seg_map_suffix='_1stHO.png',
+ reduce_zero_label=False,
+ **kwargs)
+ assert osp.exists(self.img_dir)
diff --git a/src/custom_mmpkg/custom_mmseg/datasets/cityscapes.py b/src/custom_mmpkg/custom_mmseg/datasets/cityscapes.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c4d09372290d8d1d35fc75846a2802417d6b0db
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/datasets/cityscapes.py
@@ -0,0 +1,217 @@
+import os.path as osp
+import tempfile
+
+import custom_mmpkg.custom_mmcv as mmcv
+import numpy as np
+from custom_mmpkg.custom_mmcv.utils import print_log
+from PIL import Image
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class CityscapesDataset(CustomDataset):
+ """Cityscapes dataset.
+
+ The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is
+ fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset.
+ """
+
+ CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
+ 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
+ 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
+ 'bicycle')
+
+ PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
+ [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
+ [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
+ [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
+ [0, 80, 100], [0, 0, 230], [119, 11, 32]]
+
+ def __init__(self, **kwargs):
+ super(CityscapesDataset, self).__init__(
+ img_suffix='_leftImg8bit.png',
+ seg_map_suffix='_gtFine_labelTrainIds.png',
+ **kwargs)
+
+ @staticmethod
+ def _convert_to_label_id(result):
+ """Convert trainId to id for cityscapes."""
+ if isinstance(result, str):
+ result = np.load(result)
+ import cityscapesscripts.helpers.labels as CSLabels
+ result_copy = result.copy()
+ for trainId, label in CSLabels.trainId2label.items():
+ result_copy[result == trainId] = label.id
+
+ return result_copy
+
+ def results2img(self, results, imgfile_prefix, to_label_id):
+ """Write the segmentation results to images.
+
+ Args:
+ results (list[list | tuple | ndarray]): Testing results of the
+ dataset.
+ imgfile_prefix (str): The filename prefix of the png files.
+ If the prefix is "somepath/xxx",
+ the png files will be named "somepath/xxx.png".
+ to_label_id (bool): whether convert output to label_id for
+ submission
+
+ Returns:
+ list[str: str]: result txt files which contains corresponding
+ semantic segmentation images.
+ """
+ mmcv.mkdir_or_exist(imgfile_prefix)
+ result_files = []
+ prog_bar = mmcv.ProgressBar(len(self))
+ for idx in range(len(self)):
+ result = results[idx]
+ if to_label_id:
+ result = self._convert_to_label_id(result)
+ filename = self.img_infos[idx]['filename']
+ basename = osp.splitext(osp.basename(filename))[0]
+
+ png_filename = osp.join(imgfile_prefix, f'{basename}.png')
+
+ output = Image.fromarray(result.astype(np.uint8)).convert('P')
+ import cityscapesscripts.helpers.labels as CSLabels
+ palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8)
+ for label_id, label in CSLabels.id2label.items():
+ palette[label_id] = label.color
+
+ output.putpalette(palette)
+ output.save(png_filename)
+ result_files.append(png_filename)
+ prog_bar.update()
+
+ return result_files
+
+ def format_results(self, results, imgfile_prefix=None, to_label_id=True):
+ """Format the results into dir (standard format for Cityscapes
+ evaluation).
+
+ Args:
+ results (list): Testing results of the dataset.
+ imgfile_prefix (str | None): The prefix of images files. It
+ includes the file path and the prefix of filename, e.g.,
+ "a/b/prefix". If not specified, a temp file will be created.
+ Default: None.
+ to_label_id (bool): whether convert output to label_id for
+ submission. Default: False
+
+ Returns:
+ tuple: (result_files, tmp_dir), result_files is a list containing
+ the image paths, tmp_dir is the temporal directory created
+ for saving json/png files when img_prefix is not specified.
+ """
+
+ assert isinstance(results, list), 'results must be a list'
+ assert len(results) == len(self), (
+ 'The length of results is not equal to the dataset len: '
+ f'{len(results)} != {len(self)}')
+
+ if imgfile_prefix is None:
+ tmp_dir = tempfile.TemporaryDirectory()
+ imgfile_prefix = tmp_dir.name
+ else:
+ tmp_dir = None
+ result_files = self.results2img(results, imgfile_prefix, to_label_id)
+
+ return result_files, tmp_dir
+
+ def evaluate(self,
+ results,
+ metric='mIoU',
+ logger=None,
+ imgfile_prefix=None,
+ efficient_test=False):
+ """Evaluation in Cityscapes/default protocol.
+
+ Args:
+ results (list): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated.
+ logger (logging.Logger | None | str): Logger used for printing
+ related information during evaluation. Default: None.
+ imgfile_prefix (str | None): The prefix of output image file,
+ for cityscapes evaluation only. It includes the file path and
+ the prefix of filename, e.g., "a/b/prefix".
+ If results are evaluated with cityscapes protocol, it would be
+ the prefix of output png files. The output files would be
+ png images under folder "a/b/prefix/xxx.png", where "xxx" is
+ the image name of cityscapes. If not specified, a temp file
+ will be created for evaluation.
+ Default: None.
+
+ Returns:
+ dict[str, float]: Cityscapes/default metrics.
+ """
+
+ eval_results = dict()
+ metrics = metric.copy() if isinstance(metric, list) else [metric]
+ if 'cityscapes' in metrics:
+ eval_results.update(
+ self._evaluate_cityscapes(results, logger, imgfile_prefix))
+ metrics.remove('cityscapes')
+ if len(metrics) > 0:
+ eval_results.update(
+ super(CityscapesDataset,
+ self).evaluate(results, metrics, logger, efficient_test))
+
+ return eval_results
+
+ def _evaluate_cityscapes(self, results, logger, imgfile_prefix):
+ """Evaluation in Cityscapes protocol.
+
+ Args:
+ results (list): Testing results of the dataset.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+ imgfile_prefix (str | None): The prefix of output image file
+
+ Returns:
+ dict[str: float]: Cityscapes evaluation results.
+ """
+ try:
+ import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa
+ except ImportError:
+ raise ImportError('Please run "pip install cityscapesscripts" to '
+ 'install cityscapesscripts first.')
+ msg = 'Evaluating in Cityscapes style'
+ if logger is None:
+ msg = '\n' + msg
+ print_log(msg, logger=logger)
+
+ result_files, tmp_dir = self.format_results(results, imgfile_prefix)
+
+ if tmp_dir is None:
+ result_dir = imgfile_prefix
+ else:
+ result_dir = tmp_dir.name
+
+ eval_results = dict()
+ print_log(f'Evaluating results under {result_dir} ...', logger=logger)
+
+ CSEval.args.evalInstLevelScore = True
+ CSEval.args.predictionPath = osp.abspath(result_dir)
+ CSEval.args.evalPixelAccuracy = True
+ CSEval.args.JSONOutput = False
+
+ seg_map_list = []
+ pred_list = []
+
+ # when evaluating with official cityscapesscripts,
+ # **_gtFine_labelIds.png is used
+ for seg_map in mmcv.scandir(
+ self.ann_dir, 'gtFine_labelIds.png', recursive=True):
+ seg_map_list.append(osp.join(self.ann_dir, seg_map))
+ pred_list.append(CSEval.getPrediction(CSEval.args, seg_map))
+
+ eval_results.update(
+ CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args))
+
+ if tmp_dir is not None:
+ tmp_dir.cleanup()
+
+ return eval_results
diff --git a/src/custom_mmpkg/custom_mmseg/datasets/custom.py b/src/custom_mmpkg/custom_mmseg/datasets/custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..5096a1d718784fcfcc6ae0b30aa256dfb57bc768
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/datasets/custom.py
@@ -0,0 +1,403 @@
+import os
+import os.path as osp
+from collections import OrderedDict
+from functools import reduce
+
+import custom_mmpkg.custom_mmcv as mmcv
+import numpy as np
+from custom_mmpkg.custom_mmcv.utils import print_log
+from torch.utils.data import Dataset
+
+from custom_mmpkg.custom_mmseg.core import eval_metrics
+from custom_mmpkg.custom_mmseg.utils import get_root_logger
+from .builder import DATASETS
+from .pipelines import Compose
+
+
+@DATASETS.register_module()
+class CustomDataset(Dataset):
+ """Custom dataset for semantic segmentation. An example of file structure
+ is as followed.
+
+ .. code-block:: none
+
+ ├── data
+ │ ├── my_dataset
+ │ │ ├── img_dir
+ │ │ │ ├── train
+ │ │ │ │ ├── xxx{img_suffix}
+ │ │ │ │ ├── yyy{img_suffix}
+ │ │ │ │ ├── zzz{img_suffix}
+ │ │ │ ├── val
+ │ │ ├── ann_dir
+ │ │ │ ├── train
+ │ │ │ │ ├── xxx{seg_map_suffix}
+ │ │ │ │ ├── yyy{seg_map_suffix}
+ │ │ │ │ ├── zzz{seg_map_suffix}
+ │ │ │ ├── val
+
+ The img/gt_semantic_seg pair of CustomDataset should be of the same
+ except suffix. A valid img/gt_semantic_seg filename pair should be like
+ ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included
+ in the suffix). If split is given, then ``xxx`` is specified in txt file.
+ Otherwise, all files in ``img_dir/``and ``ann_dir`` will be loaded.
+ Please refer to ``docs/tutorials/new_dataset.md`` for more details.
+
+
+ Args:
+ pipeline (list[dict]): Processing pipeline
+ img_dir (str): Path to image directory
+ img_suffix (str): Suffix of images. Default: '.jpg'
+ ann_dir (str, optional): Path to annotation directory. Default: None
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ split (str, optional): Split txt file. If split is specified, only
+ file with suffix in the splits will be loaded. Otherwise, all
+ images in img_dir/ann_dir will be loaded. Default: None
+ data_root (str, optional): Data root for img_dir/ann_dir. Default:
+ None.
+ test_mode (bool): If test_mode=True, gt wouldn't be loaded.
+ ignore_index (int): The label index to be ignored. Default: 255
+ reduce_zero_label (bool): Whether to mark label zero as ignored.
+ Default: False
+ classes (str | Sequence[str], optional): Specify classes to load.
+ If is None, ``cls.CLASSES`` will be used. Default: None.
+ palette (Sequence[Sequence[int]]] | np.ndarray | None):
+ The palette of segmentation map. If None is given, and
+ self.PALETTE is None, random palette will be generated.
+ Default: None
+ """
+
+ CLASSES = None
+
+ PALETTE = None
+
+ def __init__(self,
+ pipeline,
+ img_dir,
+ img_suffix='.jpg',
+ ann_dir=None,
+ seg_map_suffix='.png',
+ split=None,
+ data_root=None,
+ test_mode=False,
+ ignore_index=255,
+ reduce_zero_label=False,
+ classes=None,
+ palette=None):
+ self.pipeline = Compose(pipeline)
+ self.img_dir = img_dir
+ self.img_suffix = img_suffix
+ self.ann_dir = ann_dir
+ self.seg_map_suffix = seg_map_suffix
+ self.split = split
+ self.data_root = data_root
+ self.test_mode = test_mode
+ self.ignore_index = ignore_index
+ self.reduce_zero_label = reduce_zero_label
+ self.label_map = None
+ self.CLASSES, self.PALETTE = self.get_classes_and_palette(
+ classes, palette)
+
+ # join paths if data_root is specified
+ if self.data_root is not None:
+ if not osp.isabs(self.img_dir):
+ self.img_dir = osp.join(self.data_root, self.img_dir)
+ if not (self.ann_dir is None or osp.isabs(self.ann_dir)):
+ self.ann_dir = osp.join(self.data_root, self.ann_dir)
+ if not (self.split is None or osp.isabs(self.split)):
+ self.split = osp.join(self.data_root, self.split)
+
+ # load annotations
+ self.img_infos = self.load_annotations(self.img_dir, self.img_suffix,
+ self.ann_dir,
+ self.seg_map_suffix, self.split)
+
+ def __len__(self):
+ """Total number of samples of data."""
+ return len(self.img_infos)
+
+ def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix,
+ split):
+ """Load annotation from directory.
+
+ Args:
+ img_dir (str): Path to image directory
+ img_suffix (str): Suffix of images.
+ ann_dir (str|None): Path to annotation directory.
+ seg_map_suffix (str|None): Suffix of segmentation maps.
+ split (str|None): Split txt file. If split is specified, only file
+ with suffix in the splits will be loaded. Otherwise, all images
+ in img_dir/ann_dir will be loaded. Default: None
+
+ Returns:
+ list[dict]: All image info of dataset.
+ """
+
+ img_infos = []
+ if split is not None:
+ with open(split) as f:
+ for line in f:
+ img_name = line.strip()
+ img_info = dict(filename=img_name + img_suffix)
+ if ann_dir is not None:
+ seg_map = img_name + seg_map_suffix
+ img_info['ann'] = dict(seg_map=seg_map)
+ img_infos.append(img_info)
+ else:
+ for img in mmcv.scandir(img_dir, img_suffix, recursive=True):
+ img_info = dict(filename=img)
+ if ann_dir is not None:
+ seg_map = img.replace(img_suffix, seg_map_suffix)
+ img_info['ann'] = dict(seg_map=seg_map)
+ img_infos.append(img_info)
+
+ print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())
+ return img_infos
+
+ def get_ann_info(self, idx):
+ """Get annotation by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Annotation info of specified index.
+ """
+
+ return self.img_infos[idx]['ann']
+
+ def pre_pipeline(self, results):
+ """Prepare results dict for pipeline."""
+ results['seg_fields'] = []
+ results['img_prefix'] = self.img_dir
+ results['seg_prefix'] = self.ann_dir
+ if self.custom_classes:
+ results['label_map'] = self.label_map
+
+ def __getitem__(self, idx):
+ """Get training/test data after pipeline.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Training/test data (with annotation if `test_mode` is set
+ False).
+ """
+
+ if self.test_mode:
+ return self.prepare_test_img(idx)
+ else:
+ return self.prepare_train_img(idx)
+
+ def prepare_train_img(self, idx):
+ """Get training data and annotations after pipeline.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Training data and annotation after pipeline with new keys
+ introduced by pipeline.
+ """
+
+ img_info = self.img_infos[idx]
+ ann_info = self.get_ann_info(idx)
+ results = dict(img_info=img_info, ann_info=ann_info)
+ self.pre_pipeline(results)
+ return self.pipeline(results)
+
+ def prepare_test_img(self, idx):
+ """Get testing data after pipeline.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Testing data after pipeline with new keys introduced by
+ pipeline.
+ """
+
+ img_info = self.img_infos[idx]
+ results = dict(img_info=img_info)
+ self.pre_pipeline(results)
+ return self.pipeline(results)
+
+ def format_results(self, results, **kwargs):
+ """Place holder to format result to dataset specific output."""
+
+ def get_gt_seg_maps(self, efficient_test=False):
+ """Get ground truth segmentation maps for evaluation."""
+ gt_seg_maps = []
+ for img_info in self.img_infos:
+ seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map'])
+ if efficient_test:
+ gt_seg_map = seg_map
+ else:
+ gt_seg_map = mmcv.imread(
+ seg_map, flag='unchanged', backend='pillow')
+ gt_seg_maps.append(gt_seg_map)
+ return gt_seg_maps
+
+ def get_classes_and_palette(self, classes=None, palette=None):
+ """Get class names of current dataset.
+
+ Args:
+ classes (Sequence[str] | str | None): If classes is None, use
+ default CLASSES defined by builtin dataset. If classes is a
+ string, take it as a file name. The file contains the name of
+ classes where each line contains one class name. If classes is
+ a tuple or list, override the CLASSES defined by the dataset.
+ palette (Sequence[Sequence[int]]] | np.ndarray | None):
+ The palette of segmentation map. If None is given, random
+ palette will be generated. Default: None
+ """
+ if classes is None:
+ self.custom_classes = False
+ return self.CLASSES, self.PALETTE
+
+ self.custom_classes = True
+ if isinstance(classes, str):
+ # take it as a file path
+ class_names = mmcv.list_from_file(classes)
+ elif isinstance(classes, (tuple, list)):
+ class_names = classes
+ else:
+ raise ValueError(f'Unsupported type {type(classes)} of classes.')
+
+ if self.CLASSES:
+ if not set(classes).issubset(self.CLASSES):
+ raise ValueError('classes is not a subset of CLASSES.')
+
+ # dictionary, its keys are the old label ids and its values
+ # are the new label ids.
+ # used for changing pixel labels in load_annotations.
+ self.label_map = {}
+ for i, c in enumerate(self.CLASSES):
+ if c not in class_names:
+ self.label_map[i] = -1
+ else:
+ self.label_map[i] = classes.index(c)
+
+ palette = self.get_palette_for_custom_classes(class_names, palette)
+
+ return class_names, palette
+
+ def get_palette_for_custom_classes(self, class_names, palette=None):
+
+ if self.label_map is not None:
+ # return subset of palette
+ palette = []
+ for old_id, new_id in sorted(
+ self.label_map.items(), key=lambda x: x[1]):
+ if new_id != -1:
+ palette.append(self.PALETTE[old_id])
+ palette = type(self.PALETTE)(palette)
+
+ elif palette is None:
+ if self.PALETTE is None:
+ palette = np.random.randint(0, 255, size=(len(class_names), 3))
+ else:
+ palette = self.PALETTE
+
+ return palette
+
+ def evaluate(self,
+ results,
+ metric='mIoU',
+ logger=None,
+ efficient_test=False,
+ **kwargs):
+ """Evaluate the dataset.
+
+ Args:
+ results (list): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated. 'mIoU',
+ 'mDice' and 'mFscore' are supported.
+ logger (logging.Logger | None | str): Logger used for printing
+ related information during evaluation. Default: None.
+
+ Returns:
+ dict[str, float]: Default metrics.
+ """
+
+ if isinstance(metric, str):
+ metric = [metric]
+ allowed_metrics = ['mIoU', 'mDice', 'mFscore']
+ if not set(metric).issubset(set(allowed_metrics)):
+ raise KeyError('metric {} is not supported'.format(metric))
+ eval_results = {}
+ gt_seg_maps = self.get_gt_seg_maps(efficient_test)
+ if self.CLASSES is None:
+ num_classes = len(
+ reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
+ else:
+ num_classes = len(self.CLASSES)
+ ret_metrics = eval_metrics(
+ results,
+ gt_seg_maps,
+ num_classes,
+ self.ignore_index,
+ metric,
+ label_map=self.label_map,
+ reduce_zero_label=self.reduce_zero_label)
+
+ if self.CLASSES is None:
+ class_names = tuple(range(num_classes))
+ else:
+ class_names = self.CLASSES
+
+ # summary table
+ ret_metrics_summary = OrderedDict({
+ ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
+ for ret_metric, ret_metric_value in ret_metrics.items()
+ })
+
+ # each class table
+ ret_metrics.pop('aAcc', None)
+ ret_metrics_class = OrderedDict({
+ ret_metric: np.round(ret_metric_value * 100, 2)
+ for ret_metric, ret_metric_value in ret_metrics.items()
+ })
+ ret_metrics_class.update({'Class': class_names})
+ ret_metrics_class.move_to_end('Class', last=False)
+
+ try:
+ from prettytable import PrettyTable
+ # for logger
+ class_table_data = PrettyTable()
+ for key, val in ret_metrics_class.items():
+ class_table_data.add_column(key, val)
+
+ summary_table_data = PrettyTable()
+ for key, val in ret_metrics_summary.items():
+ if key == 'aAcc':
+ summary_table_data.add_column(key, [val])
+ else:
+ summary_table_data.add_column('m' + key, [val])
+
+ print_log('per class results:', logger)
+ print_log('\n' + class_table_data.get_string(), logger=logger)
+ print_log('Summary:', logger)
+ print_log('\n' + summary_table_data.get_string(), logger=logger)
+ except ImportError: # prettytable is not installed
+ pass
+
+ # each metric dict
+ for key, value in ret_metrics_summary.items():
+ if key == 'aAcc':
+ eval_results[key] = value / 100.0
+ else:
+ eval_results['m' + key] = value / 100.0
+
+ ret_metrics_class.pop('Class', None)
+ for key, value in ret_metrics_class.items():
+ eval_results.update({
+ key + '.' + str(name): value[idx] / 100.0
+ for idx, name in enumerate(class_names)
+ })
+
+ if mmcv.is_list_of(results, str):
+ for file_name in results:
+ os.remove(file_name)
+ return eval_results
diff --git a/src/custom_mmpkg/custom_mmseg/datasets/dataset_wrappers.py b/src/custom_mmpkg/custom_mmseg/datasets/dataset_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6a5e957ec3b44465432617cf6e8f0b86a8a5efa
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/datasets/dataset_wrappers.py
@@ -0,0 +1,50 @@
+from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
+
+from .builder import DATASETS
+
+
+@DATASETS.register_module()
+class ConcatDataset(_ConcatDataset):
+ """A wrapper of concatenated dataset.
+
+ Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
+ concat the group flag for image aspect ratio.
+
+ Args:
+ datasets (list[:obj:`Dataset`]): A list of datasets.
+ """
+
+ def __init__(self, datasets):
+ super(ConcatDataset, self).__init__(datasets)
+ self.CLASSES = datasets[0].CLASSES
+ self.PALETTE = datasets[0].PALETTE
+
+
+@DATASETS.register_module()
+class RepeatDataset(object):
+ """A wrapper of repeated dataset.
+
+ The length of repeated dataset will be `times` larger than the original
+ dataset. This is useful when the data loading time is long but the dataset
+ is small. Using RepeatDataset can reduce the data loading time between
+ epochs.
+
+ Args:
+ dataset (:obj:`Dataset`): The dataset to be repeated.
+ times (int): Repeat times.
+ """
+
+ def __init__(self, dataset, times):
+ self.dataset = dataset
+ self.times = times
+ self.CLASSES = dataset.CLASSES
+ self.PALETTE = dataset.PALETTE
+ self._ori_len = len(self.dataset)
+
+ def __getitem__(self, idx):
+ """Get item from original dataset."""
+ return self.dataset[idx % self._ori_len]
+
+ def __len__(self):
+ """The length is multiplied by ``times``"""
+ return self.times * self._ori_len
diff --git a/src/custom_mmpkg/custom_mmseg/datasets/drive.py b/src/custom_mmpkg/custom_mmseg/datasets/drive.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbfda8ae74bdf26c5aef197ff2866a7c7ad0cfd
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/datasets/drive.py
@@ -0,0 +1,27 @@
+import os.path as osp
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class DRIVEDataset(CustomDataset):
+ """DRIVE dataset.
+
+ In segmentation map annotation for DRIVE, 0 stands for background, which is
+ included in 2 categories. ``reduce_zero_label`` is fixed to False. The
+ ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
+ '_manual1.png'.
+ """
+
+ CLASSES = ('background', 'vessel')
+
+ PALETTE = [[120, 120, 120], [6, 230, 230]]
+
+ def __init__(self, **kwargs):
+ super(DRIVEDataset, self).__init__(
+ img_suffix='.png',
+ seg_map_suffix='_manual1.png',
+ reduce_zero_label=False,
+ **kwargs)
+ assert osp.exists(self.img_dir)
diff --git a/src/custom_mmpkg/custom_mmseg/datasets/hrf.py b/src/custom_mmpkg/custom_mmseg/datasets/hrf.py
new file mode 100644
index 0000000000000000000000000000000000000000..923203b51377f9344277fc561803d7a78bd2c684
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/datasets/hrf.py
@@ -0,0 +1,27 @@
+import os.path as osp
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class HRFDataset(CustomDataset):
+ """HRF dataset.
+
+ In segmentation map annotation for HRF, 0 stands for background, which is
+ included in 2 categories. ``reduce_zero_label`` is fixed to False. The
+ ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
+ '.png'.
+ """
+
+ CLASSES = ('background', 'vessel')
+
+ PALETTE = [[120, 120, 120], [6, 230, 230]]
+
+ def __init__(self, **kwargs):
+ super(HRFDataset, self).__init__(
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ reduce_zero_label=False,
+ **kwargs)
+ assert osp.exists(self.img_dir)
diff --git a/src/custom_mmpkg/custom_mmseg/datasets/pascal_context.py b/src/custom_mmpkg/custom_mmseg/datasets/pascal_context.py
new file mode 100644
index 0000000000000000000000000000000000000000..541a63c66a13fb16fd52921e755715ad8d078fdd
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/datasets/pascal_context.py
@@ -0,0 +1,103 @@
+import os.path as osp
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class PascalContextDataset(CustomDataset):
+ """PascalContext dataset.
+
+ In segmentation map annotation for PascalContext, 0 stands for background,
+ which is included in 60 categories. ``reduce_zero_label`` is fixed to
+ False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
+ fixed to '.png'.
+
+ Args:
+ split (str): Split txt file for PascalContext.
+ """
+
+ CLASSES = ('background', 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench',
+ 'bicycle', 'bird', 'boat', 'book', 'bottle', 'building', 'bus',
+ 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth',
+ 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence',
+ 'floor', 'flower', 'food', 'grass', 'ground', 'horse',
+ 'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person',
+ 'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep',
+ 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table',
+ 'track', 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water',
+ 'window', 'wood')
+
+ PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]
+
+ def __init__(self, split, **kwargs):
+ super(PascalContextDataset, self).__init__(
+ img_suffix='.jpg',
+ seg_map_suffix='.png',
+ split=split,
+ reduce_zero_label=False,
+ **kwargs)
+ assert osp.exists(self.img_dir) and self.split is not None
+
+
+@DATASETS.register_module()
+class PascalContextDataset59(CustomDataset):
+ """PascalContext dataset.
+
+ In segmentation map annotation for PascalContext, 0 stands for background,
+ which is included in 60 categories. ``reduce_zero_label`` is fixed to
+ False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
+ fixed to '.png'.
+
+ Args:
+ split (str): Split txt file for PascalContext.
+ """
+
+ CLASSES = ('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle',
+ 'bird', 'boat', 'book', 'bottle', 'building', 'bus', 'cabinet',
+ 'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow',
+ 'cup', 'curtain', 'dog', 'door', 'fence', 'floor', 'flower',
+ 'food', 'grass', 'ground', 'horse', 'keyboard', 'light',
+ 'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform',
+ 'pottedplant', 'road', 'rock', 'sheep', 'shelves', 'sidewalk',
+ 'sign', 'sky', 'snow', 'sofa', 'table', 'track', 'train',
+ 'tree', 'truck', 'tvmonitor', 'wall', 'water', 'window', 'wood')
+
+ PALETTE = [[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
+ [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230],
+ [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61],
+ [120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140],
+ [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200],
+ [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71],
+ [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92],
+ [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6],
+ [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8],
+ [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8],
+ [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255],
+ [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140],
+ [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0],
+ [255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0],
+ [0, 235, 255], [0, 173, 255], [31, 0, 255]]
+
+ def __init__(self, split, **kwargs):
+ super(PascalContextDataset59, self).__init__(
+ img_suffix='.jpg',
+ seg_map_suffix='.png',
+ split=split,
+ reduce_zero_label=True,
+ **kwargs)
+ assert osp.exists(self.img_dir) and self.split is not None
diff --git a/src/custom_mmpkg/custom_mmseg/datasets/pipelines/__init__.py b/src/custom_mmpkg/custom_mmseg/datasets/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b9046b07bb4ddea7a707a392b42e72db7c9df67
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/datasets/pipelines/__init__.py
@@ -0,0 +1,16 @@
+from .compose import Compose
+from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor,
+ Transpose, to_tensor)
+from .loading import LoadAnnotations, LoadImageFromFile
+from .test_time_aug import MultiScaleFlipAug
+from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
+ PhotoMetricDistortion, RandomCrop, RandomFlip,
+ RandomRotate, Rerange, Resize, RGB2Gray, SegRescale)
+
+__all__ = [
+ 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
+ 'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
+ 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
+ 'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
+ 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray'
+]
diff --git a/src/custom_mmpkg/custom_mmseg/datasets/pipelines/compose.py b/src/custom_mmpkg/custom_mmseg/datasets/pipelines/compose.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9c8027c235140c6d1cca510bb4d2c81baf439c2
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/datasets/pipelines/compose.py
@@ -0,0 +1,51 @@
+import collections
+
+from custom_mmpkg.custom_mmcv.utils import build_from_cfg
+
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class Compose(object):
+ """Compose multiple transforms sequentially.
+
+ Args:
+ transforms (Sequence[dict | callable]): Sequence of transform object or
+ config dict to be composed.
+ """
+
+ def __init__(self, transforms):
+ assert isinstance(transforms, collections.abc.Sequence)
+ self.transforms = []
+ for transform in transforms:
+ if isinstance(transform, dict):
+ transform = build_from_cfg(transform, PIPELINES)
+ self.transforms.append(transform)
+ elif callable(transform):
+ self.transforms.append(transform)
+ else:
+ raise TypeError('transform must be callable or a dict')
+
+ def __call__(self, data):
+ """Call function to apply transforms sequentially.
+
+ Args:
+ data (dict): A result dict contains the data to transform.
+
+ Returns:
+ dict: Transformed data.
+ """
+
+ for t in self.transforms:
+ data = t(data)
+ if data is None:
+ return None
+ return data
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + '('
+ for t in self.transforms:
+ format_string += '\n'
+ format_string += f' {t}'
+ format_string += '\n)'
+ return format_string
diff --git a/src/custom_mmpkg/custom_mmseg/datasets/pipelines/formating.py b/src/custom_mmpkg/custom_mmseg/datasets/pipelines/formating.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c259f185c9a55faf083dc3bec6d571902125e2d
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/datasets/pipelines/formating.py
@@ -0,0 +1,288 @@
+from collections.abc import Sequence
+
+import custom_mmpkg.custom_mmcv as mmcv
+import numpy as np
+import torch
+from custom_mmpkg.custom_mmcv.parallel import DataContainer as DC
+
+from ..builder import PIPELINES
+
+
+def to_tensor(data):
+ """Convert objects of various python types to :obj:`torch.Tensor`.
+
+ Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
+ :class:`Sequence`, :class:`int` and :class:`float`.
+
+ Args:
+ data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
+ be converted.
+ """
+
+ if isinstance(data, torch.Tensor):
+ return data
+ elif isinstance(data, np.ndarray):
+ return torch.from_numpy(data)
+ elif isinstance(data, Sequence) and not mmcv.is_str(data):
+ return torch.tensor(data)
+ elif isinstance(data, int):
+ return torch.LongTensor([data])
+ elif isinstance(data, float):
+ return torch.FloatTensor([data])
+ else:
+ raise TypeError(f'type {type(data)} cannot be converted to tensor.')
+
+
+@PIPELINES.register_module()
+class ToTensor(object):
+ """Convert some results to :obj:`torch.Tensor` by given keys.
+
+ Args:
+ keys (Sequence[str]): Keys that need to be converted to Tensor.
+ """
+
+ def __init__(self, keys):
+ self.keys = keys
+
+ def __call__(self, results):
+ """Call function to convert data in results to :obj:`torch.Tensor`.
+
+ Args:
+ results (dict): Result dict contains the data to convert.
+
+ Returns:
+ dict: The result dict contains the data converted
+ to :obj:`torch.Tensor`.
+ """
+
+ for key in self.keys:
+ results[key] = to_tensor(results[key])
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(keys={self.keys})'
+
+
+@PIPELINES.register_module()
+class ImageToTensor(object):
+ """Convert image to :obj:`torch.Tensor` by given keys.
+
+ The dimension order of input image is (H, W, C). The pipeline will convert
+ it to (C, H, W). If only 2 dimension (H, W) is given, the output would be
+ (1, H, W).
+
+ Args:
+ keys (Sequence[str]): Key of images to be converted to Tensor.
+ """
+
+ def __init__(self, keys):
+ self.keys = keys
+
+ def __call__(self, results):
+ """Call function to convert image in results to :obj:`torch.Tensor` and
+ transpose the channel order.
+
+ Args:
+ results (dict): Result dict contains the image data to convert.
+
+ Returns:
+ dict: The result dict contains the image converted
+ to :obj:`torch.Tensor` and transposed to (C, H, W) order.
+ """
+
+ for key in self.keys:
+ img = results[key]
+ if len(img.shape) < 3:
+ img = np.expand_dims(img, -1)
+ results[key] = to_tensor(img.transpose(2, 0, 1))
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(keys={self.keys})'
+
+
+@PIPELINES.register_module()
+class Transpose(object):
+ """Transpose some results by given keys.
+
+ Args:
+ keys (Sequence[str]): Keys of results to be transposed.
+ order (Sequence[int]): Order of transpose.
+ """
+
+ def __init__(self, keys, order):
+ self.keys = keys
+ self.order = order
+
+ def __call__(self, results):
+ """Call function to convert image in results to :obj:`torch.Tensor` and
+ transpose the channel order.
+
+ Args:
+ results (dict): Result dict contains the image data to convert.
+
+ Returns:
+ dict: The result dict contains the image converted
+ to :obj:`torch.Tensor` and transposed to (C, H, W) order.
+ """
+
+ for key in self.keys:
+ results[key] = results[key].transpose(self.order)
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + \
+ f'(keys={self.keys}, order={self.order})'
+
+
+@PIPELINES.register_module()
+class ToDataContainer(object):
+ """Convert results to :obj:`mmcv.DataContainer` by given fields.
+
+ Args:
+ fields (Sequence[dict]): Each field is a dict like
+ ``dict(key='xxx', **kwargs)``. The ``key`` in result will
+ be converted to :obj:`mmcv.DataContainer` with ``**kwargs``.
+ Default: ``(dict(key='img', stack=True),
+ dict(key='gt_semantic_seg'))``.
+ """
+
+ def __init__(self,
+ fields=(dict(key='img',
+ stack=True), dict(key='gt_semantic_seg'))):
+ self.fields = fields
+
+ def __call__(self, results):
+ """Call function to convert data in results to
+ :obj:`mmcv.DataContainer`.
+
+ Args:
+ results (dict): Result dict contains the data to convert.
+
+ Returns:
+ dict: The result dict contains the data converted to
+ :obj:`mmcv.DataContainer`.
+ """
+
+ for field in self.fields:
+ field = field.copy()
+ key = field.pop('key')
+ results[key] = DC(results[key], **field)
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(fields={self.fields})'
+
+
+@PIPELINES.register_module()
+class DefaultFormatBundle(object):
+ """Default formatting bundle.
+
+ It simplifies the pipeline of formatting common fields, including "img"
+ and "gt_semantic_seg". These fields are formatted as follows.
+
+ - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True)
+ - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor,
+ (3)to DataContainer (stack=True)
+ """
+
+ def __call__(self, results):
+ """Call function to transform and format common fields in results.
+
+ Args:
+ results (dict): Result dict contains the data to convert.
+
+ Returns:
+ dict: The result dict contains the data that is formatted with
+ default bundle.
+ """
+
+ if 'img' in results:
+ img = results['img']
+ if len(img.shape) < 3:
+ img = np.expand_dims(img, -1)
+ img = np.ascontiguousarray(img.transpose(2, 0, 1))
+ results['img'] = DC(to_tensor(img), stack=True)
+ if 'gt_semantic_seg' in results:
+ # convert to long
+ results['gt_semantic_seg'] = DC(
+ to_tensor(results['gt_semantic_seg'][None,
+ ...].astype(np.int64)),
+ stack=True)
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__
+
+
+@PIPELINES.register_module()
+class Collect(object):
+ """Collect data from the loader relevant to the specific task.
+
+ This is usually the last stage of the data loader pipeline. Typically keys
+ is set to some subset of "img", "gt_semantic_seg".
+
+ The "img_meta" item is always populated. The contents of the "img_meta"
+ dictionary depends on "meta_keys". By default this includes:
+
+ - "img_shape": shape of the image input to the network as a tuple
+ (h, w, c). Note that images may be zero padded on the bottom/right
+ if the batch tensor is larger than this shape.
+
+ - "scale_factor": a float indicating the preprocessing scale
+
+ - "flip": a boolean indicating if image flip transform was used
+
+ - "filename": path to the image file
+
+ - "ori_shape": original shape of the image as a tuple (h, w, c)
+
+ - "pad_shape": image shape after padding
+
+ - "img_norm_cfg": a dict of normalization information:
+ - mean - per channel mean subtraction
+ - std - per channel std divisor
+ - to_rgb - bool indicating if bgr was converted to rgb
+
+ Args:
+ keys (Sequence[str]): Keys of results to be collected in ``data``.
+ meta_keys (Sequence[str], optional): Meta keys to be converted to
+ ``mmcv.DataContainer`` and collected in ``data[img_metas]``.
+ Default: ``('filename', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'img_norm_cfg')``
+ """
+
+ def __init__(self,
+ keys,
+ meta_keys=('filename', 'ori_filename', 'ori_shape',
+ 'img_shape', 'pad_shape', 'scale_factor', 'flip',
+ 'flip_direction', 'img_norm_cfg')):
+ self.keys = keys
+ self.meta_keys = meta_keys
+
+ def __call__(self, results):
+ """Call function to collect keys in results. The keys in ``meta_keys``
+ will be converted to :obj:mmcv.DataContainer.
+
+ Args:
+ results (dict): Result dict contains the data to collect.
+
+ Returns:
+ dict: The result dict contains the following keys
+ - keys in``self.keys``
+ - ``img_metas``
+ """
+
+ data = {}
+ img_meta = {}
+ for key in self.meta_keys:
+ img_meta[key] = results[key]
+ data['img_metas'] = DC(img_meta, cpu_only=True)
+ for key in self.keys:
+ data[key] = results[key]
+ return data
+
+ def __repr__(self):
+ return self.__class__.__name__ + \
+ f'(keys={self.keys}, meta_keys={self.meta_keys})'
diff --git a/src/custom_mmpkg/custom_mmseg/datasets/pipelines/loading.py b/src/custom_mmpkg/custom_mmseg/datasets/pipelines/loading.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ef470c7a4b09deaaee6ca145f5f686610e38497
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/datasets/pipelines/loading.py
@@ -0,0 +1,153 @@
+import os.path as osp
+
+import custom_mmpkg.custom_mmcv as mmcv
+import numpy as np
+
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class LoadImageFromFile(object):
+ """Load an image from file.
+
+ Required keys are "img_prefix" and "img_info" (a dict that must contain the
+ key "filename"). Added or updated keys are "filename", "img", "img_shape",
+ "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
+ "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).
+
+ Args:
+ to_float32 (bool): Whether to convert the loaded image to a float32
+ numpy array. If set to False, the loaded image is an uint8 array.
+ Defaults to False.
+ color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
+ Defaults to 'color'.
+ file_client_args (dict): Arguments to instantiate a FileClient.
+ See :class:`mmcv.fileio.FileClient` for details.
+ Defaults to ``dict(backend='disk')``.
+ imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default:
+ 'cv2'
+ """
+
+ def __init__(self,
+ to_float32=False,
+ color_type='color',
+ file_client_args=dict(backend='disk'),
+ imdecode_backend='cv2'):
+ self.to_float32 = to_float32
+ self.color_type = color_type
+ self.file_client_args = file_client_args.copy()
+ self.file_client = None
+ self.imdecode_backend = imdecode_backend
+
+ def __call__(self, results):
+ """Call functions to load image and get image meta information.
+
+ Args:
+ results (dict): Result dict from :obj:`mmseg.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded image and meta information.
+ """
+
+ if self.file_client is None:
+ self.file_client = mmcv.FileClient(**self.file_client_args)
+
+ if results.get('img_prefix') is not None:
+ filename = osp.join(results['img_prefix'],
+ results['img_info']['filename'])
+ else:
+ filename = results['img_info']['filename']
+ img_bytes = self.file_client.get(filename)
+ img = mmcv.imfrombytes(
+ img_bytes, flag=self.color_type, backend=self.imdecode_backend)
+ if self.to_float32:
+ img = img.astype(np.float32)
+
+ results['filename'] = filename
+ results['ori_filename'] = results['img_info']['filename']
+ results['img'] = img
+ results['img_shape'] = img.shape
+ results['ori_shape'] = img.shape
+ # Set initial values for default meta_keys
+ results['pad_shape'] = img.shape
+ results['scale_factor'] = 1.0
+ num_channels = 1 if len(img.shape) < 3 else img.shape[2]
+ results['img_norm_cfg'] = dict(
+ mean=np.zeros(num_channels, dtype=np.float32),
+ std=np.ones(num_channels, dtype=np.float32),
+ to_rgb=False)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(to_float32={self.to_float32},'
+ repr_str += f"color_type='{self.color_type}',"
+ repr_str += f"imdecode_backend='{self.imdecode_backend}')"
+ return repr_str
+
+
+@PIPELINES.register_module()
+class LoadAnnotations(object):
+ """Load annotations for semantic segmentation.
+
+ Args:
+ reduce_zero_label (bool): Whether reduce all label value by 1.
+ Usually used for datasets where 0 is background label.
+ Default: False.
+ file_client_args (dict): Arguments to instantiate a FileClient.
+ See :class:`mmcv.fileio.FileClient` for details.
+ Defaults to ``dict(backend='disk')``.
+ imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default:
+ 'pillow'
+ """
+
+ def __init__(self,
+ reduce_zero_label=False,
+ file_client_args=dict(backend='disk'),
+ imdecode_backend='pillow'):
+ self.reduce_zero_label = reduce_zero_label
+ self.file_client_args = file_client_args.copy()
+ self.file_client = None
+ self.imdecode_backend = imdecode_backend
+
+ def __call__(self, results):
+ """Call function to load multiple types annotations.
+
+ Args:
+ results (dict): Result dict from :obj:`mmseg.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded semantic segmentation annotations.
+ """
+
+ if self.file_client is None:
+ self.file_client = mmcv.FileClient(**self.file_client_args)
+
+ if results.get('seg_prefix', None) is not None:
+ filename = osp.join(results['seg_prefix'],
+ results['ann_info']['seg_map'])
+ else:
+ filename = results['ann_info']['seg_map']
+ img_bytes = self.file_client.get(filename)
+ gt_semantic_seg = mmcv.imfrombytes(
+ img_bytes, flag='unchanged',
+ backend=self.imdecode_backend).squeeze().astype(np.uint8)
+ # modify if custom classes
+ if results.get('label_map', None) is not None:
+ for old_id, new_id in results['label_map'].items():
+ gt_semantic_seg[gt_semantic_seg == old_id] = new_id
+ # reduce zero_label
+ if self.reduce_zero_label:
+ # avoid using underflow conversion
+ gt_semantic_seg[gt_semantic_seg == 0] = 255
+ gt_semantic_seg = gt_semantic_seg - 1
+ gt_semantic_seg[gt_semantic_seg == 254] = 255
+ results['gt_semantic_seg'] = gt_semantic_seg
+ results['seg_fields'].append('gt_semantic_seg')
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(reduce_zero_label={self.reduce_zero_label},'
+ repr_str += f"imdecode_backend='{self.imdecode_backend}')"
+ return repr_str
diff --git a/src/custom_mmpkg/custom_mmseg/datasets/pipelines/test_time_aug.py b/src/custom_mmpkg/custom_mmseg/datasets/pipelines/test_time_aug.py
new file mode 100644
index 0000000000000000000000000000000000000000..93fe21433378b9c87d9e45243c550755bcafefe5
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/datasets/pipelines/test_time_aug.py
@@ -0,0 +1,133 @@
+import warnings
+
+import custom_mmpkg.custom_mmcv as mmcv
+
+from ..builder import PIPELINES
+from .compose import Compose
+
+
+@PIPELINES.register_module()
+class MultiScaleFlipAug(object):
+ """Test-time augmentation with multiple scales and flipping.
+
+ An example configuration is as followed:
+
+ .. code-block::
+
+ img_scale=(2048, 1024),
+ img_ratios=[0.5, 1.0],
+ flip=True,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ]
+
+ After MultiScaleFLipAug with above configuration, the results are wrapped
+ into lists of the same length as followed:
+
+ .. code-block::
+
+ dict(
+ img=[...],
+ img_shape=[...],
+ scale=[(1024, 512), (1024, 512), (2048, 1024), (2048, 1024)]
+ flip=[False, True, False, True]
+ ...
+ )
+
+ Args:
+ transforms (list[dict]): Transforms to apply in each augmentation.
+ img_scale (None | tuple | list[tuple]): Images scales for resizing.
+ img_ratios (float | list[float]): Image ratios for resizing
+ flip (bool): Whether apply flip augmentation. Default: False.
+ flip_direction (str | list[str]): Flip augmentation directions,
+ options are "horizontal" and "vertical". If flip_direction is list,
+ multiple flip augmentations will be applied.
+ It has no effect when flip == False. Default: "horizontal".
+ """
+
+ def __init__(self,
+ transforms,
+ img_scale,
+ img_ratios=None,
+ flip=False,
+ flip_direction='horizontal'):
+ self.transforms = Compose(transforms)
+ if img_ratios is not None:
+ img_ratios = img_ratios if isinstance(img_ratios,
+ list) else [img_ratios]
+ assert mmcv.is_list_of(img_ratios, float)
+ if img_scale is None:
+ # mode 1: given img_scale=None and a range of image ratio
+ self.img_scale = None
+ assert mmcv.is_list_of(img_ratios, float)
+ elif isinstance(img_scale, tuple) and mmcv.is_list_of(
+ img_ratios, float):
+ assert len(img_scale) == 2
+ # mode 2: given a scale and a range of image ratio
+ self.img_scale = [(int(img_scale[0] * ratio),
+ int(img_scale[1] * ratio))
+ for ratio in img_ratios]
+ else:
+ # mode 3: given multiple scales
+ self.img_scale = img_scale if isinstance(img_scale,
+ list) else [img_scale]
+ assert mmcv.is_list_of(self.img_scale, tuple) or self.img_scale is None
+ self.flip = flip
+ self.img_ratios = img_ratios
+ self.flip_direction = flip_direction if isinstance(
+ flip_direction, list) else [flip_direction]
+ assert mmcv.is_list_of(self.flip_direction, str)
+ if not self.flip and self.flip_direction != ['horizontal']:
+ warnings.warn(
+ 'flip_direction has no effect when flip is set to False')
+ if (self.flip
+ and not any([t['type'] == 'RandomFlip' for t in transforms])):
+ warnings.warn(
+ 'flip has no effect when RandomFlip is not in transforms')
+
+ def __call__(self, results):
+ """Call function to apply test time augment transforms on results.
+
+ Args:
+ results (dict): Result dict contains the data to transform.
+
+ Returns:
+ dict[str: list]: The augmented data, where each value is wrapped
+ into a list.
+ """
+
+ aug_data = []
+ if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float):
+ h, w = results['img'].shape[:2]
+ img_scale = [(int(w * ratio), int(h * ratio))
+ for ratio in self.img_ratios]
+ else:
+ img_scale = self.img_scale
+ flip_aug = [False, True] if self.flip else [False]
+ for scale in img_scale:
+ for flip in flip_aug:
+ for direction in self.flip_direction:
+ _results = results.copy()
+ _results['scale'] = scale
+ _results['flip'] = flip
+ _results['flip_direction'] = direction
+ data = self.transforms(_results)
+ aug_data.append(data)
+ # list of dict to dict of list
+ aug_data_dict = {key: [] for key in aug_data[0]}
+ for data in aug_data:
+ for key, val in data.items():
+ aug_data_dict[key].append(val)
+ return aug_data_dict
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(transforms={self.transforms}, '
+ repr_str += f'img_scale={self.img_scale}, flip={self.flip})'
+ repr_str += f'flip_direction={self.flip_direction}'
+ return repr_str
diff --git a/src/custom_mmpkg/custom_mmseg/datasets/pipelines/transforms.py b/src/custom_mmpkg/custom_mmseg/datasets/pipelines/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..677191de984592456c145fe83579a049879443d4
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/datasets/pipelines/transforms.py
@@ -0,0 +1,889 @@
+import custom_mmpkg.custom_mmcv as mmcv
+import numpy as np
+from custom_mmpkg.custom_mmcv.utils import deprecated_api_warning, is_tuple_of
+from numpy import random
+
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class Resize(object):
+ """Resize images & seg.
+
+ This transform resizes the input image to some scale. If the input dict
+ contains the key "scale", then the scale in the input dict is used,
+ otherwise the specified scale in the init method is used.
+
+ ``img_scale`` can be None, a tuple (single-scale) or a list of tuple
+ (multi-scale). There are 4 multiscale modes:
+
+ - ``ratio_range is not None``:
+ 1. When img_scale is None, img_scale is the shape of image in results
+ (img_scale = results['img'].shape[:2]) and the image is resized based
+ on the original size. (mode 1)
+ 2. When img_scale is a tuple (single-scale), randomly sample a ratio from
+ the ratio range and multiply it with the image scale. (mode 2)
+
+ - ``ratio_range is None and multiscale_mode == "range"``: randomly sample a
+ scale from the a range. (mode 3)
+
+ - ``ratio_range is None and multiscale_mode == "value"``: randomly sample a
+ scale from multiple scales. (mode 4)
+
+ Args:
+ img_scale (tuple or list[tuple]): Images scales for resizing.
+ multiscale_mode (str): Either "range" or "value".
+ ratio_range (tuple[float]): (min_ratio, max_ratio)
+ keep_ratio (bool): Whether to keep the aspect ratio when resizing the
+ image.
+ """
+
+ def __init__(self,
+ img_scale=None,
+ multiscale_mode='range',
+ ratio_range=None,
+ keep_ratio=True):
+ if img_scale is None:
+ self.img_scale = None
+ else:
+ if isinstance(img_scale, list):
+ self.img_scale = img_scale
+ else:
+ self.img_scale = [img_scale]
+ assert mmcv.is_list_of(self.img_scale, tuple)
+
+ if ratio_range is not None:
+ # mode 1: given img_scale=None and a range of image ratio
+ # mode 2: given a scale and a range of image ratio
+ assert self.img_scale is None or len(self.img_scale) == 1
+ else:
+ # mode 3 and 4: given multiple scales or a range of scales
+ assert multiscale_mode in ['value', 'range']
+
+ self.multiscale_mode = multiscale_mode
+ self.ratio_range = ratio_range
+ self.keep_ratio = keep_ratio
+
+ @staticmethod
+ def random_select(img_scales):
+ """Randomly select an img_scale from given candidates.
+
+ Args:
+ img_scales (list[tuple]): Images scales for selection.
+
+ Returns:
+ (tuple, int): Returns a tuple ``(img_scale, scale_dix)``,
+ where ``img_scale`` is the selected image scale and
+ ``scale_idx`` is the selected index in the given candidates.
+ """
+
+ assert mmcv.is_list_of(img_scales, tuple)
+ scale_idx = np.random.randint(len(img_scales))
+ img_scale = img_scales[scale_idx]
+ return img_scale, scale_idx
+
+ @staticmethod
+ def random_sample(img_scales):
+ """Randomly sample an img_scale when ``multiscale_mode=='range'``.
+
+ Args:
+ img_scales (list[tuple]): Images scale range for sampling.
+ There must be two tuples in img_scales, which specify the lower
+ and upper bound of image scales.
+
+ Returns:
+ (tuple, None): Returns a tuple ``(img_scale, None)``, where
+ ``img_scale`` is sampled scale and None is just a placeholder
+ to be consistent with :func:`random_select`.
+ """
+
+ assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2
+ img_scale_long = [max(s) for s in img_scales]
+ img_scale_short = [min(s) for s in img_scales]
+ long_edge = np.random.randint(
+ min(img_scale_long),
+ max(img_scale_long) + 1)
+ short_edge = np.random.randint(
+ min(img_scale_short),
+ max(img_scale_short) + 1)
+ img_scale = (long_edge, short_edge)
+ return img_scale, None
+
+ @staticmethod
+ def random_sample_ratio(img_scale, ratio_range):
+ """Randomly sample an img_scale when ``ratio_range`` is specified.
+
+ A ratio will be randomly sampled from the range specified by
+ ``ratio_range``. Then it would be multiplied with ``img_scale`` to
+ generate sampled scale.
+
+ Args:
+ img_scale (tuple): Images scale base to multiply with ratio.
+ ratio_range (tuple[float]): The minimum and maximum ratio to scale
+ the ``img_scale``.
+
+ Returns:
+ (tuple, None): Returns a tuple ``(scale, None)``, where
+ ``scale`` is sampled ratio multiplied with ``img_scale`` and
+ None is just a placeholder to be consistent with
+ :func:`random_select`.
+ """
+
+ assert isinstance(img_scale, tuple) and len(img_scale) == 2
+ min_ratio, max_ratio = ratio_range
+ assert min_ratio <= max_ratio
+ ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
+ scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)
+ return scale, None
+
+ def _random_scale(self, results):
+ """Randomly sample an img_scale according to ``ratio_range`` and
+ ``multiscale_mode``.
+
+ If ``ratio_range`` is specified, a ratio will be sampled and be
+ multiplied with ``img_scale``.
+ If multiple scales are specified by ``img_scale``, a scale will be
+ sampled according to ``multiscale_mode``.
+ Otherwise, single scale will be used.
+
+ Args:
+ results (dict): Result dict from :obj:`dataset`.
+
+ Returns:
+ dict: Two new keys 'scale` and 'scale_idx` are added into
+ ``results``, which would be used by subsequent pipelines.
+ """
+
+ if self.ratio_range is not None:
+ if self.img_scale is None:
+ h, w = results['img'].shape[:2]
+ scale, scale_idx = self.random_sample_ratio((w, h),
+ self.ratio_range)
+ else:
+ scale, scale_idx = self.random_sample_ratio(
+ self.img_scale[0], self.ratio_range)
+ elif len(self.img_scale) == 1:
+ scale, scale_idx = self.img_scale[0], 0
+ elif self.multiscale_mode == 'range':
+ scale, scale_idx = self.random_sample(self.img_scale)
+ elif self.multiscale_mode == 'value':
+ scale, scale_idx = self.random_select(self.img_scale)
+ else:
+ raise NotImplementedError
+
+ results['scale'] = scale
+ results['scale_idx'] = scale_idx
+
+ def _resize_img(self, results):
+ """Resize images with ``results['scale']``."""
+ if self.keep_ratio:
+ img, scale_factor = mmcv.imrescale(
+ results['img'], results['scale'], return_scale=True)
+ # the w_scale and h_scale has minor difference
+ # a real fix should be done in the mmcv.imrescale in the future
+ new_h, new_w = img.shape[:2]
+ h, w = results['img'].shape[:2]
+ w_scale = new_w / w
+ h_scale = new_h / h
+ else:
+ img, w_scale, h_scale = mmcv.imresize(
+ results['img'], results['scale'], return_scale=True)
+ scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],
+ dtype=np.float32)
+ results['img'] = img
+ results['img_shape'] = img.shape
+ results['pad_shape'] = img.shape # in case that there is no padding
+ results['scale_factor'] = scale_factor
+ results['keep_ratio'] = self.keep_ratio
+
+ def _resize_seg(self, results):
+ """Resize semantic segmentation map with ``results['scale']``."""
+ for key in results.get('seg_fields', []):
+ if self.keep_ratio:
+ gt_seg = mmcv.imrescale(
+ results[key], results['scale'], interpolation='nearest')
+ else:
+ gt_seg = mmcv.imresize(
+ results[key], results['scale'], interpolation='nearest')
+ results[key] = gt_seg
+
+ def __call__(self, results):
+ """Call function to resize images, bounding boxes, masks, semantic
+ segmentation map.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor',
+ 'keep_ratio' keys are added into result dict.
+ """
+
+ if 'scale' not in results:
+ self._random_scale(results)
+ self._resize_img(results)
+ self._resize_seg(results)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += (f'(img_scale={self.img_scale}, '
+ f'multiscale_mode={self.multiscale_mode}, '
+ f'ratio_range={self.ratio_range}, '
+ f'keep_ratio={self.keep_ratio})')
+ return repr_str
+
+
+@PIPELINES.register_module()
+class RandomFlip(object):
+ """Flip the image & seg.
+
+ If the input dict contains the key "flip", then the flag will be used,
+ otherwise it will be randomly decided by a ratio specified in the init
+ method.
+
+ Args:
+ prob (float, optional): The flipping probability. Default: None.
+ direction(str, optional): The flipping direction. Options are
+ 'horizontal' and 'vertical'. Default: 'horizontal'.
+ """
+
+ @deprecated_api_warning({'flip_ratio': 'prob'}, cls_name='RandomFlip')
+ def __init__(self, prob=None, direction='horizontal'):
+ self.prob = prob
+ self.direction = direction
+ if prob is not None:
+ assert prob >= 0 and prob <= 1
+ assert direction in ['horizontal', 'vertical']
+
+ def __call__(self, results):
+ """Call function to flip bounding boxes, masks, semantic segmentation
+ maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Flipped results, 'flip', 'flip_direction' keys are added into
+ result dict.
+ """
+
+ if 'flip' not in results:
+ flip = True if np.random.rand() < self.prob else False
+ results['flip'] = flip
+ if 'flip_direction' not in results:
+ results['flip_direction'] = self.direction
+ if results['flip']:
+ # flip image
+ results['img'] = mmcv.imflip(
+ results['img'], direction=results['flip_direction'])
+
+ # flip segs
+ for key in results.get('seg_fields', []):
+ # use copy() to make numpy stride positive
+ results[key] = mmcv.imflip(
+ results[key], direction=results['flip_direction']).copy()
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(prob={self.prob})'
+
+
+@PIPELINES.register_module()
+class Pad(object):
+ """Pad the image & mask.
+
+ There are two padding modes: (1) pad to a fixed size and (2) pad to the
+ minimum size that is divisible by some number.
+ Added keys are "pad_shape", "pad_fixed_size", "pad_size_divisor",
+
+ Args:
+ size (tuple, optional): Fixed padding size.
+ size_divisor (int, optional): The divisor of padded size.
+ pad_val (float, optional): Padding value. Default: 0.
+ seg_pad_val (float, optional): Padding value of segmentation map.
+ Default: 255.
+ """
+
+ def __init__(self,
+ size=None,
+ size_divisor=None,
+ pad_val=0,
+ seg_pad_val=255):
+ self.size = size
+ self.size_divisor = size_divisor
+ self.pad_val = pad_val
+ self.seg_pad_val = seg_pad_val
+ # only one of size and size_divisor should be valid
+ assert size is not None or size_divisor is not None
+ assert size is None or size_divisor is None
+
+ def _pad_img(self, results):
+ """Pad images according to ``self.size``."""
+ if self.size is not None:
+ padded_img = mmcv.impad(
+ results['img'], shape=self.size, pad_val=self.pad_val)
+ elif self.size_divisor is not None:
+ padded_img = mmcv.impad_to_multiple(
+ results['img'], self.size_divisor, pad_val=self.pad_val)
+ results['img'] = padded_img
+ results['pad_shape'] = padded_img.shape
+ results['pad_fixed_size'] = self.size
+ results['pad_size_divisor'] = self.size_divisor
+
+ def _pad_seg(self, results):
+ """Pad masks according to ``results['pad_shape']``."""
+ for key in results.get('seg_fields', []):
+ results[key] = mmcv.impad(
+ results[key],
+ shape=results['pad_shape'][:2],
+ pad_val=self.seg_pad_val)
+
+ def __call__(self, results):
+ """Call function to pad images, masks, semantic segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Updated result dict.
+ """
+
+ self._pad_img(results)
+ self._pad_seg(results)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(size={self.size}, size_divisor={self.size_divisor}, ' \
+ f'pad_val={self.pad_val})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class Normalize(object):
+ """Normalize the image.
+
+ Added key is "img_norm_cfg".
+
+ Args:
+ mean (sequence): Mean values of 3 channels.
+ std (sequence): Std values of 3 channels.
+ to_rgb (bool): Whether to convert the image from BGR to RGB,
+ default is true.
+ """
+
+ def __init__(self, mean, std, to_rgb=True):
+ self.mean = np.array(mean, dtype=np.float32)
+ self.std = np.array(std, dtype=np.float32)
+ self.to_rgb = to_rgb
+
+ def __call__(self, results):
+ """Call function to normalize images.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Normalized results, 'img_norm_cfg' key is added into
+ result dict.
+ """
+
+ results['img'] = mmcv.imnormalize(results['img'], self.mean, self.std,
+ self.to_rgb)
+ results['img_norm_cfg'] = dict(
+ mean=self.mean, std=self.std, to_rgb=self.to_rgb)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(mean={self.mean}, std={self.std}, to_rgb=' \
+ f'{self.to_rgb})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class Rerange(object):
+ """Rerange the image pixel value.
+
+ Args:
+ min_value (float or int): Minimum value of the reranged image.
+ Default: 0.
+ max_value (float or int): Maximum value of the reranged image.
+ Default: 255.
+ """
+
+ def __init__(self, min_value=0, max_value=255):
+ assert isinstance(min_value, float) or isinstance(min_value, int)
+ assert isinstance(max_value, float) or isinstance(max_value, int)
+ assert min_value < max_value
+ self.min_value = min_value
+ self.max_value = max_value
+
+ def __call__(self, results):
+ """Call function to rerange images.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+ Returns:
+ dict: Reranged results.
+ """
+
+ img = results['img']
+ img_min_value = np.min(img)
+ img_max_value = np.max(img)
+
+ assert img_min_value < img_max_value
+ # rerange to [0, 1]
+ img = (img - img_min_value) / (img_max_value - img_min_value)
+ # rerange to [min_value, max_value]
+ img = img * (self.max_value - self.min_value) + self.min_value
+ results['img'] = img
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(min_value={self.min_value}, max_value={self.max_value})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class CLAHE(object):
+ """Use CLAHE method to process the image.
+
+ See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
+ Graphics Gems, 1994:474-485.` for more information.
+
+ Args:
+ clip_limit (float): Threshold for contrast limiting. Default: 40.0.
+ tile_grid_size (tuple[int]): Size of grid for histogram equalization.
+ Input image will be divided into equally sized rectangular tiles.
+ It defines the number of tiles in row and column. Default: (8, 8).
+ """
+
+ def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)):
+ assert isinstance(clip_limit, (float, int))
+ self.clip_limit = clip_limit
+ assert is_tuple_of(tile_grid_size, int)
+ assert len(tile_grid_size) == 2
+ self.tile_grid_size = tile_grid_size
+
+ def __call__(self, results):
+ """Call function to Use CLAHE method process images.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Processed results.
+ """
+
+ for i in range(results['img'].shape[2]):
+ results['img'][:, :, i] = mmcv.clahe(
+ np.array(results['img'][:, :, i], dtype=np.uint8),
+ self.clip_limit, self.tile_grid_size)
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(clip_limit={self.clip_limit}, '\
+ f'tile_grid_size={self.tile_grid_size})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class RandomCrop(object):
+ """Random crop the image & seg.
+
+ Args:
+ crop_size (tuple): Expected size after cropping, (h, w).
+ cat_max_ratio (float): The maximum ratio that single category could
+ occupy.
+ """
+
+ def __init__(self, crop_size, cat_max_ratio=1., ignore_index=255):
+ assert crop_size[0] > 0 and crop_size[1] > 0
+ self.crop_size = crop_size
+ self.cat_max_ratio = cat_max_ratio
+ self.ignore_index = ignore_index
+
+ def get_crop_bbox(self, img):
+ """Randomly get a crop bounding box."""
+ margin_h = max(img.shape[0] - self.crop_size[0], 0)
+ margin_w = max(img.shape[1] - self.crop_size[1], 0)
+ offset_h = np.random.randint(0, margin_h + 1)
+ offset_w = np.random.randint(0, margin_w + 1)
+ crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]
+ crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]
+
+ return crop_y1, crop_y2, crop_x1, crop_x2
+
+ def crop(self, img, crop_bbox):
+ """Crop from ``img``"""
+ crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
+ img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
+ return img
+
+ def __call__(self, results):
+ """Call function to randomly crop images, semantic segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Randomly cropped results, 'img_shape' key in result dict is
+ updated according to crop size.
+ """
+
+ img = results['img']
+ crop_bbox = self.get_crop_bbox(img)
+ if self.cat_max_ratio < 1.:
+ # Repeat 10 times
+ for _ in range(10):
+ seg_temp = self.crop(results['gt_semantic_seg'], crop_bbox)
+ labels, cnt = np.unique(seg_temp, return_counts=True)
+ cnt = cnt[labels != self.ignore_index]
+ if len(cnt) > 1 and np.max(cnt) / np.sum(
+ cnt) < self.cat_max_ratio:
+ break
+ crop_bbox = self.get_crop_bbox(img)
+
+ # crop the image
+ img = self.crop(img, crop_bbox)
+ img_shape = img.shape
+ results['img'] = img
+ results['img_shape'] = img_shape
+
+ # crop semantic seg
+ for key in results.get('seg_fields', []):
+ results[key] = self.crop(results[key], crop_bbox)
+
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(crop_size={self.crop_size})'
+
+
+@PIPELINES.register_module()
+class RandomRotate(object):
+ """Rotate the image & seg.
+
+ Args:
+ prob (float): The rotation probability.
+ degree (float, tuple[float]): Range of degrees to select from. If
+ degree is a number instead of tuple like (min, max),
+ the range of degree will be (``-degree``, ``+degree``)
+ pad_val (float, optional): Padding value of image. Default: 0.
+ seg_pad_val (float, optional): Padding value of segmentation map.
+ Default: 255.
+ center (tuple[float], optional): Center point (w, h) of the rotation in
+ the source image. If not specified, the center of the image will be
+ used. Default: None.
+ auto_bound (bool): Whether to adjust the image size to cover the whole
+ rotated image. Default: False
+ """
+
+ def __init__(self,
+ prob,
+ degree,
+ pad_val=0,
+ seg_pad_val=255,
+ center=None,
+ auto_bound=False):
+ self.prob = prob
+ assert prob >= 0 and prob <= 1
+ if isinstance(degree, (float, int)):
+ assert degree > 0, f'degree {degree} should be positive'
+ self.degree = (-degree, degree)
+ else:
+ self.degree = degree
+ assert len(self.degree) == 2, f'degree {self.degree} should be a ' \
+ f'tuple of (min, max)'
+ self.pal_val = pad_val
+ self.seg_pad_val = seg_pad_val
+ self.center = center
+ self.auto_bound = auto_bound
+
+ def __call__(self, results):
+ """Call function to rotate image, semantic segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Rotated results.
+ """
+
+ rotate = True if np.random.rand() < self.prob else False
+ degree = np.random.uniform(min(*self.degree), max(*self.degree))
+ if rotate:
+ # rotate image
+ results['img'] = mmcv.imrotate(
+ results['img'],
+ angle=degree,
+ border_value=self.pal_val,
+ center=self.center,
+ auto_bound=self.auto_bound)
+
+ # rotate segs
+ for key in results.get('seg_fields', []):
+ results[key] = mmcv.imrotate(
+ results[key],
+ angle=degree,
+ border_value=self.seg_pad_val,
+ center=self.center,
+ auto_bound=self.auto_bound,
+ interpolation='nearest')
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(prob={self.prob}, ' \
+ f'degree={self.degree}, ' \
+ f'pad_val={self.pal_val}, ' \
+ f'seg_pad_val={self.seg_pad_val}, ' \
+ f'center={self.center}, ' \
+ f'auto_bound={self.auto_bound})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class RGB2Gray(object):
+ """Convert RGB image to grayscale image.
+
+ This transform calculate the weighted mean of input image channels with
+ ``weights`` and then expand the channels to ``out_channels``. When
+ ``out_channels`` is None, the number of output channels is the same as
+ input channels.
+
+ Args:
+ out_channels (int): Expected number of output channels after
+ transforming. Default: None.
+ weights (tuple[float]): The weights to calculate the weighted mean.
+ Default: (0.299, 0.587, 0.114).
+ """
+
+ def __init__(self, out_channels=None, weights=(0.299, 0.587, 0.114)):
+ assert out_channels is None or out_channels > 0
+ self.out_channels = out_channels
+ assert isinstance(weights, tuple)
+ for item in weights:
+ assert isinstance(item, (float, int))
+ self.weights = weights
+
+ def __call__(self, results):
+ """Call function to convert RGB image to grayscale image.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Result dict with grayscale image.
+ """
+ img = results['img']
+ assert len(img.shape) == 3
+ assert img.shape[2] == len(self.weights)
+ weights = np.array(self.weights).reshape((1, 1, -1))
+ img = (img * weights).sum(2, keepdims=True)
+ if self.out_channels is None:
+ img = img.repeat(weights.shape[2], axis=2)
+ else:
+ img = img.repeat(self.out_channels, axis=2)
+
+ results['img'] = img
+ results['img_shape'] = img.shape
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(out_channels={self.out_channels}, ' \
+ f'weights={self.weights})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class AdjustGamma(object):
+ """Using gamma correction to process the image.
+
+ Args:
+ gamma (float or int): Gamma value used in gamma correction.
+ Default: 1.0.
+ """
+
+ def __init__(self, gamma=1.0):
+ assert isinstance(gamma, float) or isinstance(gamma, int)
+ assert gamma > 0
+ self.gamma = gamma
+ inv_gamma = 1.0 / gamma
+ self.table = np.array([(i / 255.0)**inv_gamma * 255
+ for i in np.arange(256)]).astype('uint8')
+
+ def __call__(self, results):
+ """Call function to process the image with gamma correction.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Processed results.
+ """
+
+ results['img'] = mmcv.lut_transform(
+ np.array(results['img'], dtype=np.uint8), self.table)
+
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(gamma={self.gamma})'
+
+
+@PIPELINES.register_module()
+class SegRescale(object):
+ """Rescale semantic segmentation maps.
+
+ Args:
+ scale_factor (float): The scale factor of the final output.
+ """
+
+ def __init__(self, scale_factor=1):
+ self.scale_factor = scale_factor
+
+ def __call__(self, results):
+ """Call function to scale the semantic segmentation map.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Result dict with semantic segmentation map scaled.
+ """
+ for key in results.get('seg_fields', []):
+ if self.scale_factor != 1:
+ results[key] = mmcv.imrescale(
+ results[key], self.scale_factor, interpolation='nearest')
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(scale_factor={self.scale_factor})'
+
+
+@PIPELINES.register_module()
+class PhotoMetricDistortion(object):
+ """Apply photometric distortion to image sequentially, every transformation
+ is applied with a probability of 0.5. The position of random contrast is in
+ second or second to last.
+
+ 1. random brightness
+ 2. random contrast (mode 0)
+ 3. convert color from BGR to HSV
+ 4. random saturation
+ 5. random hue
+ 6. convert color from HSV to BGR
+ 7. random contrast (mode 1)
+
+ Args:
+ brightness_delta (int): delta of brightness.
+ contrast_range (tuple): range of contrast.
+ saturation_range (tuple): range of saturation.
+ hue_delta (int): delta of hue.
+ """
+
+ def __init__(self,
+ brightness_delta=32,
+ contrast_range=(0.5, 1.5),
+ saturation_range=(0.5, 1.5),
+ hue_delta=18):
+ self.brightness_delta = brightness_delta
+ self.contrast_lower, self.contrast_upper = contrast_range
+ self.saturation_lower, self.saturation_upper = saturation_range
+ self.hue_delta = hue_delta
+
+ def convert(self, img, alpha=1, beta=0):
+ """Multiple with alpha and add beat with clip."""
+ img = img.astype(np.float32) * alpha + beta
+ img = np.clip(img, 0, 255)
+ return img.astype(np.uint8)
+
+ def brightness(self, img):
+ """Brightness distortion."""
+ if random.randint(2):
+ return self.convert(
+ img,
+ beta=random.uniform(-self.brightness_delta,
+ self.brightness_delta))
+ return img
+
+ def contrast(self, img):
+ """Contrast distortion."""
+ if random.randint(2):
+ return self.convert(
+ img,
+ alpha=random.uniform(self.contrast_lower, self.contrast_upper))
+ return img
+
+ def saturation(self, img):
+ """Saturation distortion."""
+ if random.randint(2):
+ img = mmcv.bgr2hsv(img)
+ img[:, :, 1] = self.convert(
+ img[:, :, 1],
+ alpha=random.uniform(self.saturation_lower,
+ self.saturation_upper))
+ img = mmcv.hsv2bgr(img)
+ return img
+
+ def hue(self, img):
+ """Hue distortion."""
+ if random.randint(2):
+ img = mmcv.bgr2hsv(img)
+ img[:, :,
+ 0] = (img[:, :, 0].astype(int) +
+ random.randint(-self.hue_delta, self.hue_delta)) % 180
+ img = mmcv.hsv2bgr(img)
+ return img
+
+ def __call__(self, results):
+ """Call function to perform photometric distortion on images.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Result dict with images distorted.
+ """
+
+ img = results['img']
+ # random brightness
+ img = self.brightness(img)
+
+ # mode == 0 --> do random contrast first
+ # mode == 1 --> do random contrast last
+ mode = random.randint(2)
+ if mode == 1:
+ img = self.contrast(img)
+
+ # random saturation
+ img = self.saturation(img)
+
+ # random hue
+ img = self.hue(img)
+
+ # random contrast
+ if mode == 0:
+ img = self.contrast(img)
+
+ results['img'] = img
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += (f'(brightness_delta={self.brightness_delta}, '
+ f'contrast_range=({self.contrast_lower}, '
+ f'{self.contrast_upper}), '
+ f'saturation_range=({self.saturation_lower}, '
+ f'{self.saturation_upper}), '
+ f'hue_delta={self.hue_delta})')
+ return repr_str
diff --git a/src/custom_mmpkg/custom_mmseg/datasets/stare.py b/src/custom_mmpkg/custom_mmseg/datasets/stare.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbd14e0920e7f6a73baff1432e5a32ccfdb0dfae
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/datasets/stare.py
@@ -0,0 +1,27 @@
+import os.path as osp
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class STAREDataset(CustomDataset):
+ """STARE dataset.
+
+ In segmentation map annotation for STARE, 0 stands for background, which is
+ included in 2 categories. ``reduce_zero_label`` is fixed to False. The
+ ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
+ '.ah.png'.
+ """
+
+ CLASSES = ('background', 'vessel')
+
+ PALETTE = [[120, 120, 120], [6, 230, 230]]
+
+ def __init__(self, **kwargs):
+ super(STAREDataset, self).__init__(
+ img_suffix='.png',
+ seg_map_suffix='.ah.png',
+ reduce_zero_label=False,
+ **kwargs)
+ assert osp.exists(self.img_dir)
diff --git a/src/custom_mmpkg/custom_mmseg/datasets/voc.py b/src/custom_mmpkg/custom_mmseg/datasets/voc.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8855203b14ee0dc4da9099a2945d4aedcffbcd6
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/datasets/voc.py
@@ -0,0 +1,29 @@
+import os.path as osp
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class PascalVOCDataset(CustomDataset):
+ """Pascal VOC dataset.
+
+ Args:
+ split (str): Split txt file for Pascal VOC.
+ """
+
+ CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
+ 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
+ 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
+ 'train', 'tvmonitor')
+
+ PALETTE = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
+ [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
+ [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
+ [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
+ [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
+
+ def __init__(self, split, **kwargs):
+ super(PascalVOCDataset, self).__init__(
+ img_suffix='.jpg', seg_map_suffix='.png', split=split, **kwargs)
+ assert osp.exists(self.img_dir) and self.split is not None
diff --git a/src/custom_mmpkg/custom_mmseg/models/__init__.py b/src/custom_mmpkg/custom_mmseg/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cf93f8bec9cf0cef0a3bd76ca3ca92eb188f535
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/__init__.py
@@ -0,0 +1,12 @@
+from .backbones import * # noqa: F401,F403
+from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone,
+ build_head, build_loss, build_segmentor)
+from .decode_heads import * # noqa: F401,F403
+from .losses import * # noqa: F401,F403
+from .necks import * # noqa: F401,F403
+from .segmentors import * # noqa: F401,F403
+
+__all__ = [
+ 'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone',
+ 'build_head', 'build_loss', 'build_segmentor'
+]
diff --git a/src/custom_mmpkg/custom_mmseg/models/backbones/__init__.py b/src/custom_mmpkg/custom_mmseg/models/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1116c00a17c8bd9ed7f18743baee22b3b7d3f8d
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/backbones/__init__.py
@@ -0,0 +1,16 @@
+from .cgnet import CGNet
+# from .fast_scnn import FastSCNN
+from .hrnet import HRNet
+from .mobilenet_v2 import MobileNetV2
+from .mobilenet_v3 import MobileNetV3
+from .resnest import ResNeSt
+from .resnet import ResNet, ResNetV1c, ResNetV1d
+from .resnext import ResNeXt
+from .unet import UNet
+from .vit import VisionTransformer
+
+__all__ = [
+ 'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet',
+ 'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
+ 'VisionTransformer'
+]
diff --git a/src/custom_mmpkg/custom_mmseg/models/backbones/cgnet.py b/src/custom_mmpkg/custom_mmseg/models/backbones/cgnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..18b158be8dffa5e119c4f73e84d399815ec714ac
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/backbones/cgnet.py
@@ -0,0 +1,367 @@
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from custom_mmpkg.custom_mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
+ constant_init, kaiming_init)
+from custom_mmpkg.custom_mmcv.runner import load_checkpoint
+from custom_mmpkg.custom_mmcv.utils.parrots_wrapper import _BatchNorm
+
+from custom_mmpkg.custom_mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+
+
+class GlobalContextExtractor(nn.Module):
+ """Global Context Extractor for CGNet.
+
+ This class is employed to refine the joint feature of both local feature
+ and surrounding context.
+
+ Args:
+ channel (int): Number of input feature channels.
+ reduction (int): Reductions for global context extractor. Default: 16.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ """
+
+ def __init__(self, channel, reduction=16, with_cp=False):
+ super(GlobalContextExtractor, self).__init__()
+ self.channel = channel
+ self.reduction = reduction
+ assert reduction >= 1 and channel >= reduction
+ self.with_cp = with_cp
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.fc = nn.Sequential(
+ nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True),
+ nn.Linear(channel // reduction, channel), nn.Sigmoid())
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ num_batch, num_channel = x.size()[:2]
+ y = self.avg_pool(x).view(num_batch, num_channel)
+ y = self.fc(y).view(num_batch, num_channel, 1, 1)
+ return x * y
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ return out
+
+
+class ContextGuidedBlock(nn.Module):
+ """Context Guided Block for CGNet.
+
+ This class consists of four components: local feature extractor,
+ surrounding feature extractor, joint feature extractor and global
+ context extractor.
+
+ Args:
+ in_channels (int): Number of input feature channels.
+ out_channels (int): Number of output feature channels.
+ dilation (int): Dilation rate for surrounding context extractor.
+ Default: 2.
+ reduction (int): Reduction for global context extractor. Default: 16.
+ skip_connect (bool): Add input to output or not. Default: True.
+ downsample (bool): Downsample the input to 1/2 or not. Default: False.
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN', requires_grad=True).
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='PReLU').
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ dilation=2,
+ reduction=16,
+ skip_connect=True,
+ downsample=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ act_cfg=dict(type='PReLU'),
+ with_cp=False):
+ super(ContextGuidedBlock, self).__init__()
+ self.with_cp = with_cp
+ self.downsample = downsample
+
+ channels = out_channels if downsample else out_channels // 2
+ if 'type' in act_cfg and act_cfg['type'] == 'PReLU':
+ act_cfg['num_parameters'] = channels
+ kernel_size = 3 if downsample else 1
+ stride = 2 if downsample else 1
+ padding = (kernel_size - 1) // 2
+
+ self.conv1x1 = ConvModule(
+ in_channels,
+ channels,
+ kernel_size,
+ stride,
+ padding,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ self.f_loc = build_conv_layer(
+ conv_cfg,
+ channels,
+ channels,
+ kernel_size=3,
+ padding=1,
+ groups=channels,
+ bias=False)
+ self.f_sur = build_conv_layer(
+ conv_cfg,
+ channels,
+ channels,
+ kernel_size=3,
+ padding=dilation,
+ groups=channels,
+ dilation=dilation,
+ bias=False)
+
+ self.bn = build_norm_layer(norm_cfg, 2 * channels)[1]
+ self.activate = nn.PReLU(2 * channels)
+
+ if downsample:
+ self.bottleneck = build_conv_layer(
+ conv_cfg,
+ 2 * channels,
+ out_channels,
+ kernel_size=1,
+ bias=False)
+
+ self.skip_connect = skip_connect and not downsample
+ self.f_glo = GlobalContextExtractor(out_channels, reduction, with_cp)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ out = self.conv1x1(x)
+ loc = self.f_loc(out)
+ sur = self.f_sur(out)
+
+ joi_feat = torch.cat([loc, sur], 1) # the joint feature
+ joi_feat = self.bn(joi_feat)
+ joi_feat = self.activate(joi_feat)
+ if self.downsample:
+ joi_feat = self.bottleneck(joi_feat) # channel = out_channels
+ # f_glo is employed to refine the joint feature
+ out = self.f_glo(joi_feat)
+
+ if self.skip_connect:
+ return x + out
+ else:
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ return out
+
+
+class InputInjection(nn.Module):
+ """Downsampling module for CGNet."""
+
+ def __init__(self, num_downsampling):
+ super(InputInjection, self).__init__()
+ self.pool = nn.ModuleList()
+ for i in range(num_downsampling):
+ self.pool.append(nn.AvgPool2d(3, stride=2, padding=1))
+
+ def forward(self, x):
+ for pool in self.pool:
+ x = pool(x)
+ return x
+
+
+@BACKBONES.register_module()
+class CGNet(nn.Module):
+ """CGNet backbone.
+
+ A Light-weight Context Guided Network for Semantic Segmentation
+ arXiv: https://arxiv.org/abs/1811.08201
+
+ Args:
+ in_channels (int): Number of input image channels. Normally 3.
+ num_channels (tuple[int]): Numbers of feature channels at each stages.
+ Default: (32, 64, 128).
+ num_blocks (tuple[int]): Numbers of CG blocks at stage 1 and stage 2.
+ Default: (3, 21).
+ dilations (tuple[int]): Dilation rate for surrounding context
+ extractors at stage 1 and stage 2. Default: (2, 4).
+ reductions (tuple[int]): Reductions for global context extractors at
+ stage 1 and stage 2. Default: (8, 16).
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN', requires_grad=True).
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='PReLU').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ """
+
+ def __init__(self,
+ in_channels=3,
+ num_channels=(32, 64, 128),
+ num_blocks=(3, 21),
+ dilations=(2, 4),
+ reductions=(8, 16),
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ act_cfg=dict(type='PReLU'),
+ norm_eval=False,
+ with_cp=False):
+
+ super(CGNet, self).__init__()
+ self.in_channels = in_channels
+ self.num_channels = num_channels
+ assert isinstance(self.num_channels, tuple) and len(
+ self.num_channels) == 3
+ self.num_blocks = num_blocks
+ assert isinstance(self.num_blocks, tuple) and len(self.num_blocks) == 2
+ self.dilations = dilations
+ assert isinstance(self.dilations, tuple) and len(self.dilations) == 2
+ self.reductions = reductions
+ assert isinstance(self.reductions, tuple) and len(self.reductions) == 2
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ if 'type' in self.act_cfg and self.act_cfg['type'] == 'PReLU':
+ self.act_cfg['num_parameters'] = num_channels[0]
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+
+ cur_channels = in_channels
+ self.stem = nn.ModuleList()
+ for i in range(3):
+ self.stem.append(
+ ConvModule(
+ cur_channels,
+ num_channels[0],
+ 3,
+ 2 if i == 0 else 1,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ cur_channels = num_channels[0]
+
+ self.inject_2x = InputInjection(1) # down-sample for Input, factor=2
+ self.inject_4x = InputInjection(2) # down-sample for Input, factor=4
+
+ cur_channels += in_channels
+ self.norm_prelu_0 = nn.Sequential(
+ build_norm_layer(norm_cfg, cur_channels)[1],
+ nn.PReLU(cur_channels))
+
+ # stage 1
+ self.level1 = nn.ModuleList()
+ for i in range(num_blocks[0]):
+ self.level1.append(
+ ContextGuidedBlock(
+ cur_channels if i == 0 else num_channels[1],
+ num_channels[1],
+ dilations[0],
+ reductions[0],
+ downsample=(i == 0),
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ with_cp=with_cp)) # CG block
+
+ cur_channels = 2 * num_channels[1] + in_channels
+ self.norm_prelu_1 = nn.Sequential(
+ build_norm_layer(norm_cfg, cur_channels)[1],
+ nn.PReLU(cur_channels))
+
+ # stage 2
+ self.level2 = nn.ModuleList()
+ for i in range(num_blocks[1]):
+ self.level2.append(
+ ContextGuidedBlock(
+ cur_channels if i == 0 else num_channels[2],
+ num_channels[2],
+ dilations[1],
+ reductions[1],
+ downsample=(i == 0),
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ with_cp=with_cp)) # CG block
+
+ cur_channels = 2 * num_channels[2]
+ self.norm_prelu_2 = nn.Sequential(
+ build_norm_layer(norm_cfg, cur_channels)[1],
+ nn.PReLU(cur_channels))
+
+ def forward(self, x):
+ output = []
+
+ # stage 0
+ inp_2x = self.inject_2x(x)
+ inp_4x = self.inject_4x(x)
+ for layer in self.stem:
+ x = layer(x)
+ x = self.norm_prelu_0(torch.cat([x, inp_2x], 1))
+ output.append(x)
+
+ # stage 1
+ for i, layer in enumerate(self.level1):
+ x = layer(x)
+ if i == 0:
+ down1 = x
+ x = self.norm_prelu_1(torch.cat([x, down1, inp_4x], 1))
+ output.append(x)
+
+ # stage 2
+ for i, layer in enumerate(self.level2):
+ x = layer(x)
+ if i == 0:
+ down2 = x
+ x = self.norm_prelu_2(torch.cat([down2, x], 1))
+ output.append(x)
+
+ return output
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+ elif isinstance(m, nn.PReLU):
+ constant_init(m, 0)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def train(self, mode=True):
+ """Convert the model into training mode will keeping the normalization
+ layer freezed."""
+ super(CGNet, self).train(mode)
+ if mode and self.norm_eval:
+ for m in self.modules():
+ # trick: eval have effect on BatchNorm only
+ if isinstance(m, _BatchNorm):
+ m.eval()
diff --git a/src/custom_mmpkg/custom_mmseg/models/backbones/fast_scnn.py b/src/custom_mmpkg/custom_mmseg/models/backbones/fast_scnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d06faa7c4e3a0d6e85acaf3f2bd21ec28e1f435
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/backbones/fast_scnn.py
@@ -0,0 +1,375 @@
+import torch
+import torch.nn as nn
+from custom_mmpkg.custom_mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, constant_init,
+ kaiming_init)
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from custom_mmpkg.custom_mmseg.models.decode_heads.psp_head import PPM
+from custom_mmpkg.custom_mmseg.ops import resize
+from ..builder import BACKBONES
+from ..utils.inverted_residual import InvertedResidual
+
+
+class LearningToDownsample(nn.Module):
+ """Learning to downsample module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ dw_channels (tuple[int]): Number of output channels of the first and
+ the second depthwise conv (dwconv) layers.
+ out_channels (int): Number of output channels of the whole
+ 'learning to downsample' module.
+ conv_cfg (dict | None): Config of conv layers. Default: None
+ norm_cfg (dict | None): Config of norm layers. Default:
+ dict(type='BN')
+ act_cfg (dict): Config of activation layers. Default:
+ dict(type='ReLU')
+ """
+
+ def __init__(self,
+ in_channels,
+ dw_channels,
+ out_channels,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU')):
+ super(LearningToDownsample, self).__init__()
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ dw_channels1 = dw_channels[0]
+ dw_channels2 = dw_channels[1]
+
+ self.conv = ConvModule(
+ in_channels,
+ dw_channels1,
+ 3,
+ stride=2,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.dsconv1 = DepthwiseSeparableConvModule(
+ dw_channels1,
+ dw_channels2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ norm_cfg=self.norm_cfg)
+ self.dsconv2 = DepthwiseSeparableConvModule(
+ dw_channels2,
+ out_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ norm_cfg=self.norm_cfg)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.dsconv1(x)
+ x = self.dsconv2(x)
+ return x
+
+
+class GlobalFeatureExtractor(nn.Module):
+ """Global feature extractor module.
+
+ Args:
+ in_channels (int): Number of input channels of the GFE module.
+ Default: 64
+ block_channels (tuple[int]): Tuple of ints. Each int specifies the
+ number of output channels of each Inverted Residual module.
+ Default: (64, 96, 128)
+ out_channels(int): Number of output channels of the GFE module.
+ Default: 128
+ expand_ratio (int): Adjusts number of channels of the hidden layer
+ in InvertedResidual by this amount.
+ Default: 6
+ num_blocks (tuple[int]): Tuple of ints. Each int specifies the
+ number of times each Inverted Residual module is repeated.
+ The repeated Inverted Residual modules are called a 'group'.
+ Default: (3, 3, 3)
+ strides (tuple[int]): Tuple of ints. Each int specifies
+ the downsampling factor of each 'group'.
+ Default: (2, 2, 1)
+ pool_scales (tuple[int]): Tuple of ints. Each int specifies
+ the parameter required in 'global average pooling' within PPM.
+ Default: (1, 2, 3, 6)
+ conv_cfg (dict | None): Config of conv layers. Default: None
+ norm_cfg (dict | None): Config of norm layers. Default:
+ dict(type='BN')
+ act_cfg (dict): Config of activation layers. Default:
+ dict(type='ReLU')
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False
+ """
+
+ def __init__(self,
+ in_channels=64,
+ block_channels=(64, 96, 128),
+ out_channels=128,
+ expand_ratio=6,
+ num_blocks=(3, 3, 3),
+ strides=(2, 2, 1),
+ pool_scales=(1, 2, 3, 6),
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ align_corners=False):
+ super(GlobalFeatureExtractor, self).__init__()
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ assert len(block_channels) == len(num_blocks) == 3
+ self.bottleneck1 = self._make_layer(in_channels, block_channels[0],
+ num_blocks[0], strides[0],
+ expand_ratio)
+ self.bottleneck2 = self._make_layer(block_channels[0],
+ block_channels[1], num_blocks[1],
+ strides[1], expand_ratio)
+ self.bottleneck3 = self._make_layer(block_channels[1],
+ block_channels[2], num_blocks[2],
+ strides[2], expand_ratio)
+ self.ppm = PPM(
+ pool_scales,
+ block_channels[2],
+ block_channels[2] // 4,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ align_corners=align_corners)
+ self.out = ConvModule(
+ block_channels[2] * 2,
+ out_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def _make_layer(self,
+ in_channels,
+ out_channels,
+ blocks,
+ stride=1,
+ expand_ratio=6):
+ layers = [
+ InvertedResidual(
+ in_channels,
+ out_channels,
+ stride,
+ expand_ratio,
+ norm_cfg=self.norm_cfg)
+ ]
+ for i in range(1, blocks):
+ layers.append(
+ InvertedResidual(
+ out_channels,
+ out_channels,
+ 1,
+ expand_ratio,
+ norm_cfg=self.norm_cfg))
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.bottleneck1(x)
+ x = self.bottleneck2(x)
+ x = self.bottleneck3(x)
+ x = torch.cat([x, *self.ppm(x)], dim=1)
+ x = self.out(x)
+ return x
+
+
+class FeatureFusionModule(nn.Module):
+ """Feature fusion module.
+
+ Args:
+ higher_in_channels (int): Number of input channels of the
+ higher-resolution branch.
+ lower_in_channels (int): Number of input channels of the
+ lower-resolution branch.
+ out_channels (int): Number of output channels.
+ conv_cfg (dict | None): Config of conv layers. Default: None
+ norm_cfg (dict | None): Config of norm layers. Default:
+ dict(type='BN')
+ act_cfg (dict): Config of activation layers. Default:
+ dict(type='ReLU')
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False
+ """
+
+ def __init__(self,
+ higher_in_channels,
+ lower_in_channels,
+ out_channels,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ align_corners=False):
+ super(FeatureFusionModule, self).__init__()
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.align_corners = align_corners
+ self.dwconv = ConvModule(
+ lower_in_channels,
+ out_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.conv_lower_res = ConvModule(
+ out_channels,
+ out_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None)
+ self.conv_higher_res = ConvModule(
+ higher_in_channels,
+ out_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None)
+ self.relu = nn.ReLU(True)
+
+ def forward(self, higher_res_feature, lower_res_feature):
+ lower_res_feature = resize(
+ lower_res_feature,
+ size=higher_res_feature.size()[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ lower_res_feature = self.dwconv(lower_res_feature)
+ lower_res_feature = self.conv_lower_res(lower_res_feature)
+
+ higher_res_feature = self.conv_higher_res(higher_res_feature)
+ out = higher_res_feature + lower_res_feature
+ return self.relu(out)
+
+
+@BACKBONES.register_module()
+class FastSCNN(nn.Module):
+ """Fast-SCNN Backbone.
+
+ Args:
+ in_channels (int): Number of input image channels. Default: 3.
+ downsample_dw_channels (tuple[int]): Number of output channels after
+ the first conv layer & the second conv layer in
+ Learning-To-Downsample (LTD) module.
+ Default: (32, 48).
+ global_in_channels (int): Number of input channels of
+ Global Feature Extractor(GFE).
+ Equal to number of output channels of LTD.
+ Default: 64.
+ global_block_channels (tuple[int]): Tuple of integers that describe
+ the output channels for each of the MobileNet-v2 bottleneck
+ residual blocks in GFE.
+ Default: (64, 96, 128).
+ global_block_strides (tuple[int]): Tuple of integers
+ that describe the strides (downsampling factors) for each of the
+ MobileNet-v2 bottleneck residual blocks in GFE.
+ Default: (2, 2, 1).
+ global_out_channels (int): Number of output channels of GFE.
+ Default: 128.
+ higher_in_channels (int): Number of input channels of the higher
+ resolution branch in FFM.
+ Equal to global_in_channels.
+ Default: 64.
+ lower_in_channels (int): Number of input channels of the lower
+ resolution branch in FFM.
+ Equal to global_out_channels.
+ Default: 128.
+ fusion_out_channels (int): Number of output channels of FFM.
+ Default: 128.
+ out_indices (tuple): Tuple of indices of list
+ [higher_res_features, lower_res_features, fusion_output].
+ Often set to (0,1,2) to enable aux. heads.
+ Default: (0, 1, 2).
+ conv_cfg (dict | None): Config of conv layers. Default: None
+ norm_cfg (dict | None): Config of norm layers. Default:
+ dict(type='BN')
+ act_cfg (dict): Config of activation layers. Default:
+ dict(type='ReLU')
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False
+ """
+
+ def __init__(self,
+ in_channels=3,
+ downsample_dw_channels=(32, 48),
+ global_in_channels=64,
+ global_block_channels=(64, 96, 128),
+ global_block_strides=(2, 2, 1),
+ global_out_channels=128,
+ higher_in_channels=64,
+ lower_in_channels=128,
+ fusion_out_channels=128,
+ out_indices=(0, 1, 2),
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ align_corners=False):
+
+ super(FastSCNN, self).__init__()
+ if global_in_channels != higher_in_channels:
+ raise AssertionError('Global Input Channels must be the same \
+ with Higher Input Channels!')
+ elif global_out_channels != lower_in_channels:
+ raise AssertionError('Global Output Channels must be the same \
+ with Lower Input Channels!')
+
+ self.in_channels = in_channels
+ self.downsample_dw_channels1 = downsample_dw_channels[0]
+ self.downsample_dw_channels2 = downsample_dw_channels[1]
+ self.global_in_channels = global_in_channels
+ self.global_block_channels = global_block_channels
+ self.global_block_strides = global_block_strides
+ self.global_out_channels = global_out_channels
+ self.higher_in_channels = higher_in_channels
+ self.lower_in_channels = lower_in_channels
+ self.fusion_out_channels = fusion_out_channels
+ self.out_indices = out_indices
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.align_corners = align_corners
+ self.learning_to_downsample = LearningToDownsample(
+ in_channels,
+ downsample_dw_channels,
+ global_in_channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.global_feature_extractor = GlobalFeatureExtractor(
+ global_in_channels,
+ global_block_channels,
+ global_out_channels,
+ strides=self.global_block_strides,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ align_corners=self.align_corners)
+ self.feature_fusion = FeatureFusionModule(
+ higher_in_channels,
+ lower_in_channels,
+ fusion_out_channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ align_corners=self.align_corners)
+
+ def init_weights(self, pretrained=None):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+
+ def forward(self, x):
+ higher_res_features = self.learning_to_downsample(x)
+ lower_res_features = self.global_feature_extractor(higher_res_features)
+ fusion_output = self.feature_fusion(higher_res_features,
+ lower_res_features)
+
+ outs = [higher_res_features, lower_res_features, fusion_output]
+ outs = [outs[i] for i in self.out_indices]
+ return tuple(outs)
diff --git a/src/custom_mmpkg/custom_mmseg/models/backbones/hrnet.py b/src/custom_mmpkg/custom_mmseg/models/backbones/hrnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..7df19e7bef0ccacbef039633fa5c26344593bf3c
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/backbones/hrnet.py
@@ -0,0 +1,555 @@
+import torch.nn as nn
+from custom_mmpkg.custom_mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
+ kaiming_init)
+from custom_mmpkg.custom_mmcv.runner import load_checkpoint
+from custom_mmpkg.custom_mmcv.utils.parrots_wrapper import _BatchNorm
+
+from custom_mmpkg.custom_mmseg.ops import Upsample, resize
+from custom_mmpkg.custom_mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+from .resnet import BasicBlock, Bottleneck
+
+
+class HRModule(nn.Module):
+ """High-Resolution Module for HRNet.
+
+ In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange
+ is in this module.
+ """
+
+ def __init__(self,
+ num_branches,
+ blocks,
+ num_blocks,
+ in_channels,
+ num_channels,
+ multiscale_output=True,
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True)):
+ super(HRModule, self).__init__()
+ self._check_branches(num_branches, num_blocks, in_channels,
+ num_channels)
+
+ self.in_channels = in_channels
+ self.num_branches = num_branches
+
+ self.multiscale_output = multiscale_output
+ self.norm_cfg = norm_cfg
+ self.conv_cfg = conv_cfg
+ self.with_cp = with_cp
+ self.branches = self._make_branches(num_branches, blocks, num_blocks,
+ num_channels)
+ self.fuse_layers = self._make_fuse_layers()
+ self.relu = nn.ReLU(inplace=False)
+
+ def _check_branches(self, num_branches, num_blocks, in_channels,
+ num_channels):
+ """Check branches configuration."""
+ if num_branches != len(num_blocks):
+ error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_BLOCKS(' \
+ f'{len(num_blocks)})'
+ raise ValueError(error_msg)
+
+ if num_branches != len(num_channels):
+ error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_CHANNELS(' \
+ f'{len(num_channels)})'
+ raise ValueError(error_msg)
+
+ if num_branches != len(in_channels):
+ error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_INCHANNELS(' \
+ f'{len(in_channels)})'
+ raise ValueError(error_msg)
+
+ def _make_one_branch(self,
+ branch_index,
+ block,
+ num_blocks,
+ num_channels,
+ stride=1):
+ """Build one branch."""
+ downsample = None
+ if stride != 1 or \
+ self.in_channels[branch_index] != \
+ num_channels[branch_index] * block.expansion:
+ downsample = nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ self.in_channels[branch_index],
+ num_channels[branch_index] * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ build_norm_layer(self.norm_cfg, num_channels[branch_index] *
+ block.expansion)[1])
+
+ layers = []
+ layers.append(
+ block(
+ self.in_channels[branch_index],
+ num_channels[branch_index],
+ stride,
+ downsample=downsample,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+ self.in_channels[branch_index] = \
+ num_channels[branch_index] * block.expansion
+ for i in range(1, num_blocks[branch_index]):
+ layers.append(
+ block(
+ self.in_channels[branch_index],
+ num_channels[branch_index],
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+
+ return nn.Sequential(*layers)
+
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
+ """Build multiple branch."""
+ branches = []
+
+ for i in range(num_branches):
+ branches.append(
+ self._make_one_branch(i, block, num_blocks, num_channels))
+
+ return nn.ModuleList(branches)
+
+ def _make_fuse_layers(self):
+ """Build fuse layer."""
+ if self.num_branches == 1:
+ return None
+
+ num_branches = self.num_branches
+ in_channels = self.in_channels
+ fuse_layers = []
+ num_out_branches = num_branches if self.multiscale_output else 1
+ for i in range(num_out_branches):
+ fuse_layer = []
+ for j in range(num_branches):
+ if j > i:
+ fuse_layer.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[i],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False),
+ build_norm_layer(self.norm_cfg, in_channels[i])[1],
+ # we set align_corners=False for HRNet
+ Upsample(
+ scale_factor=2**(j - i),
+ mode='bilinear',
+ align_corners=False)))
+ elif j == i:
+ fuse_layer.append(None)
+ else:
+ conv_downsamples = []
+ for k in range(i - j):
+ if k == i - j - 1:
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[i],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ in_channels[i])[1]))
+ else:
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[j],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ in_channels[j])[1],
+ nn.ReLU(inplace=False)))
+ fuse_layer.append(nn.Sequential(*conv_downsamples))
+ fuse_layers.append(nn.ModuleList(fuse_layer))
+
+ return nn.ModuleList(fuse_layers)
+
+ def forward(self, x):
+ """Forward function."""
+ if self.num_branches == 1:
+ return [self.branches[0](x[0])]
+
+ for i in range(self.num_branches):
+ x[i] = self.branches[i](x[i])
+
+ x_fuse = []
+ for i in range(len(self.fuse_layers)):
+ y = 0
+ for j in range(self.num_branches):
+ if i == j:
+ y += x[j]
+ elif j > i:
+ y = y + resize(
+ self.fuse_layers[i][j](x[j]),
+ size=x[i].shape[2:],
+ mode='bilinear',
+ align_corners=False)
+ else:
+ y += self.fuse_layers[i][j](x[j])
+ x_fuse.append(self.relu(y))
+ return x_fuse
+
+
+@BACKBONES.register_module()
+class HRNet(nn.Module):
+ """HRNet backbone.
+
+ High-Resolution Representations for Labeling Pixels and Regions
+ arXiv: https://arxiv.org/abs/1904.04514
+
+ Args:
+ extra (dict): detailed configuration for each stage of HRNet.
+ in_channels (int): Number of input image channels. Normally 3.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ zero_init_residual (bool): whether to use zero init for last norm layer
+ in resblocks to let them behave as identity.
+
+ Example:
+ >>> from custom_mmpkg.custom_mmseg.models import HRNet
+ >>> import torch
+ >>> extra = dict(
+ >>> stage1=dict(
+ >>> num_modules=1,
+ >>> num_branches=1,
+ >>> block='BOTTLENECK',
+ >>> num_blocks=(4, ),
+ >>> num_channels=(64, )),
+ >>> stage2=dict(
+ >>> num_modules=1,
+ >>> num_branches=2,
+ >>> block='BASIC',
+ >>> num_blocks=(4, 4),
+ >>> num_channels=(32, 64)),
+ >>> stage3=dict(
+ >>> num_modules=4,
+ >>> num_branches=3,
+ >>> block='BASIC',
+ >>> num_blocks=(4, 4, 4),
+ >>> num_channels=(32, 64, 128)),
+ >>> stage4=dict(
+ >>> num_modules=3,
+ >>> num_branches=4,
+ >>> block='BASIC',
+ >>> num_blocks=(4, 4, 4, 4),
+ >>> num_channels=(32, 64, 128, 256)))
+ >>> self = HRNet(extra, in_channels=1)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 1, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 32, 8, 8)
+ (1, 64, 4, 4)
+ (1, 128, 2, 2)
+ (1, 256, 1, 1)
+ """
+
+ blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
+
+ def __init__(self,
+ extra,
+ in_channels=3,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=False,
+ with_cp=False,
+ zero_init_residual=False):
+ super(HRNet, self).__init__()
+ self.extra = extra
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+ self.zero_init_residual = zero_init_residual
+
+ # stem net
+ self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ 64,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False)
+
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(
+ self.conv_cfg,
+ 64,
+ 64,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.relu = nn.ReLU(inplace=True)
+
+ # stage 1
+ self.stage1_cfg = self.extra['stage1']
+ num_channels = self.stage1_cfg['num_channels'][0]
+ block_type = self.stage1_cfg['block']
+ num_blocks = self.stage1_cfg['num_blocks'][0]
+
+ block = self.blocks_dict[block_type]
+ stage1_out_channels = num_channels * block.expansion
+ self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
+
+ # stage 2
+ self.stage2_cfg = self.extra['stage2']
+ num_channels = self.stage2_cfg['num_channels']
+ block_type = self.stage2_cfg['block']
+
+ block = self.blocks_dict[block_type]
+ num_channels = [channel * block.expansion for channel in num_channels]
+ self.transition1 = self._make_transition_layer([stage1_out_channels],
+ num_channels)
+ self.stage2, pre_stage_channels = self._make_stage(
+ self.stage2_cfg, num_channels)
+
+ # stage 3
+ self.stage3_cfg = self.extra['stage3']
+ num_channels = self.stage3_cfg['num_channels']
+ block_type = self.stage3_cfg['block']
+
+ block = self.blocks_dict[block_type]
+ num_channels = [channel * block.expansion for channel in num_channels]
+ self.transition2 = self._make_transition_layer(pre_stage_channels,
+ num_channels)
+ self.stage3, pre_stage_channels = self._make_stage(
+ self.stage3_cfg, num_channels)
+
+ # stage 4
+ self.stage4_cfg = self.extra['stage4']
+ num_channels = self.stage4_cfg['num_channels']
+ block_type = self.stage4_cfg['block']
+
+ block = self.blocks_dict[block_type]
+ num_channels = [channel * block.expansion for channel in num_channels]
+ self.transition3 = self._make_transition_layer(pre_stage_channels,
+ num_channels)
+ self.stage4, pre_stage_channels = self._make_stage(
+ self.stage4_cfg, num_channels)
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: the normalization layer named "norm2" """
+ return getattr(self, self.norm2_name)
+
+ def _make_transition_layer(self, num_channels_pre_layer,
+ num_channels_cur_layer):
+ """Make transition layer."""
+ num_branches_cur = len(num_channels_cur_layer)
+ num_branches_pre = len(num_channels_pre_layer)
+
+ transition_layers = []
+ for i in range(num_branches_cur):
+ if i < num_branches_pre:
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+ transition_layers.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ num_channels_pre_layer[i],
+ num_channels_cur_layer[i],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ num_channels_cur_layer[i])[1],
+ nn.ReLU(inplace=True)))
+ else:
+ transition_layers.append(None)
+ else:
+ conv_downsamples = []
+ for j in range(i + 1 - num_branches_pre):
+ in_channels = num_channels_pre_layer[-1]
+ out_channels = num_channels_cur_layer[i] \
+ if j == i - num_branches_pre else in_channels
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg, out_channels)[1],
+ nn.ReLU(inplace=True)))
+ transition_layers.append(nn.Sequential(*conv_downsamples))
+
+ return nn.ModuleList(transition_layers)
+
+ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
+ """Make each layer."""
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
+
+ layers = []
+ layers.append(
+ block(
+ inplanes,
+ planes,
+ stride,
+ downsample=downsample,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+ inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(
+ block(
+ inplanes,
+ planes,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+
+ return nn.Sequential(*layers)
+
+ def _make_stage(self, layer_config, in_channels, multiscale_output=True):
+ """Make each stage."""
+ num_modules = layer_config['num_modules']
+ num_branches = layer_config['num_branches']
+ num_blocks = layer_config['num_blocks']
+ num_channels = layer_config['num_channels']
+ block = self.blocks_dict[layer_config['block']]
+
+ hr_modules = []
+ for i in range(num_modules):
+ # multi_scale_output is only used for the last module
+ if not multiscale_output and i == num_modules - 1:
+ reset_multiscale_output = False
+ else:
+ reset_multiscale_output = True
+
+ hr_modules.append(
+ HRModule(
+ num_branches,
+ block,
+ num_blocks,
+ in_channels,
+ num_channels,
+ reset_multiscale_output,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+
+ return nn.Sequential(*hr_modules), in_channels
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+
+ if self.zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ constant_init(m.norm3, 0)
+ elif isinstance(m, BasicBlock):
+ constant_init(m.norm2, 0)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x):
+ """Forward function."""
+
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+ x = self.conv2(x)
+ x = self.norm2(x)
+ x = self.relu(x)
+ x = self.layer1(x)
+
+ x_list = []
+ for i in range(self.stage2_cfg['num_branches']):
+ if self.transition1[i] is not None:
+ x_list.append(self.transition1[i](x))
+ else:
+ x_list.append(x)
+ y_list = self.stage2(x_list)
+
+ x_list = []
+ for i in range(self.stage3_cfg['num_branches']):
+ if self.transition2[i] is not None:
+ x_list.append(self.transition2[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage3(x_list)
+
+ x_list = []
+ for i in range(self.stage4_cfg['num_branches']):
+ if self.transition3[i] is not None:
+ x_list.append(self.transition3[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage4(x_list)
+
+ return y_list
+
+ def train(self, mode=True):
+ """Convert the model into training mode will keeping the normalization
+ layer freezed."""
+ super(HRNet, self).train(mode)
+ if mode and self.norm_eval:
+ for m in self.modules():
+ # trick: eval have effect on BatchNorm only
+ if isinstance(m, _BatchNorm):
+ m.eval()
diff --git a/src/custom_mmpkg/custom_mmseg/models/backbones/mobilenet_v2.py b/src/custom_mmpkg/custom_mmseg/models/backbones/mobilenet_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcec93a22124fbc58f84cedd96d11f1e8dd90393
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/backbones/mobilenet_v2.py
@@ -0,0 +1,180 @@
+import logging
+
+import torch.nn as nn
+from custom_mmpkg.custom_mmcv.cnn import ConvModule, constant_init, kaiming_init
+from custom_mmpkg.custom_mmcv.runner import load_checkpoint
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from ..builder import BACKBONES
+from ..utils import InvertedResidual, make_divisible
+
+
+@BACKBONES.register_module()
+class MobileNetV2(nn.Module):
+ """MobileNetV2 backbone.
+
+ Args:
+ widen_factor (float): Width multiplier, multiply number of
+ channels in each layer by this amount. Default: 1.0.
+ strides (Sequence[int], optional): Strides of the first block of each
+ layer. If not specified, default config in ``arch_setting`` will
+ be used.
+ dilations (Sequence[int]): Dilation of each layer.
+ out_indices (None or Sequence[int]): Output from which stages.
+ Default: (7, ).
+ frozen_stages (int): Stages to be frozen (all param fixed).
+ Default: -1, which means not freezing any parameters.
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU6').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ """
+
+ # Parameters to build layers. 3 parameters are needed to construct a
+ # layer, from left to right: expand_ratio, channel, num_blocks.
+ arch_settings = [[1, 16, 1], [6, 24, 2], [6, 32, 3], [6, 64, 4],
+ [6, 96, 3], [6, 160, 3], [6, 320, 1]]
+
+ def __init__(self,
+ widen_factor=1.,
+ strides=(1, 2, 2, 2, 1, 2, 1),
+ dilations=(1, 1, 1, 1, 1, 1, 1),
+ out_indices=(1, 2, 4, 6),
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU6'),
+ norm_eval=False,
+ with_cp=False):
+ super(MobileNetV2, self).__init__()
+ self.widen_factor = widen_factor
+ self.strides = strides
+ self.dilations = dilations
+ assert len(strides) == len(dilations) == len(self.arch_settings)
+ self.out_indices = out_indices
+ for index in out_indices:
+ if index not in range(0, 7):
+ raise ValueError('the item in out_indices must in '
+ f'range(0, 8). But received {index}')
+
+ if frozen_stages not in range(-1, 7):
+ raise ValueError('frozen_stages must be in range(-1, 7). '
+ f'But received {frozen_stages}')
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+
+ self.in_channels = make_divisible(32 * widen_factor, 8)
+
+ self.conv1 = ConvModule(
+ in_channels=3,
+ out_channels=self.in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ self.layers = []
+
+ for i, layer_cfg in enumerate(self.arch_settings):
+ expand_ratio, channel, num_blocks = layer_cfg
+ stride = self.strides[i]
+ dilation = self.dilations[i]
+ out_channels = make_divisible(channel * widen_factor, 8)
+ inverted_res_layer = self.make_layer(
+ out_channels=out_channels,
+ num_blocks=num_blocks,
+ stride=stride,
+ dilation=dilation,
+ expand_ratio=expand_ratio)
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, inverted_res_layer)
+ self.layers.append(layer_name)
+
+ def make_layer(self, out_channels, num_blocks, stride, dilation,
+ expand_ratio):
+ """Stack InvertedResidual blocks to build a layer for MobileNetV2.
+
+ Args:
+ out_channels (int): out_channels of block.
+ num_blocks (int): Number of blocks.
+ stride (int): Stride of the first block.
+ dilation (int): Dilation of the first block.
+ expand_ratio (int): Expand the number of channels of the
+ hidden layer in InvertedResidual by this ratio.
+ """
+ layers = []
+ for i in range(num_blocks):
+ layers.append(
+ InvertedResidual(
+ self.in_channels,
+ out_channels,
+ stride if i == 0 else 1,
+ expand_ratio=expand_ratio,
+ dilation=dilation if i == 0 else 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ with_cp=self.with_cp))
+ self.in_channels = out_channels
+
+ return nn.Sequential(*layers)
+
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+ logger = logging.getLogger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x):
+ x = self.conv1(x)
+
+ outs = []
+ for i, layer_name in enumerate(self.layers):
+ layer = getattr(self, layer_name)
+ x = layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+
+ if len(outs) == 1:
+ return outs[0]
+ else:
+ return tuple(outs)
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ for param in self.conv1.parameters():
+ param.requires_grad = False
+ for i in range(1, self.frozen_stages + 1):
+ layer = getattr(self, f'layer{i}')
+ layer.eval()
+ for param in layer.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super(MobileNetV2, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, _BatchNorm):
+ m.eval()
diff --git a/src/custom_mmpkg/custom_mmseg/models/backbones/mobilenet_v3.py b/src/custom_mmpkg/custom_mmseg/models/backbones/mobilenet_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..172103273f385b8dcd4e89a7f8ee0714be87113e
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/backbones/mobilenet_v3.py
@@ -0,0 +1,255 @@
+import logging
+
+import custom_mmpkg.custom_mmcv as mmcv
+import torch.nn as nn
+from custom_mmpkg.custom_mmcv.cnn import ConvModule, constant_init, kaiming_init
+from custom_mmpkg.custom_mmcv.cnn.bricks import Conv2dAdaptivePadding
+from custom_mmpkg.custom_mmcv.runner import load_checkpoint
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from ..builder import BACKBONES
+from ..utils import InvertedResidualV3 as InvertedResidual
+
+
+@BACKBONES.register_module()
+class MobileNetV3(nn.Module):
+ """MobileNetV3 backbone.
+
+ This backbone is the improved implementation of `Searching for MobileNetV3
+ `_.
+
+ Args:
+ arch (str): Architecture of mobilnetv3, from {'small', 'large'}.
+ Default: 'small'.
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ out_indices (tuple[int]): Output from which layer.
+ Default: (0, 1, 12).
+ frozen_stages (int): Stages to be frozen (all param fixed).
+ Default: -1, which means not freezing any parameters.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save
+ some memory while slowing down the training speed.
+ Default: False.
+ """
+ # Parameters to build each block:
+ # [kernel size, mid channels, out channels, with_se, act type, stride]
+ arch_settings = {
+ 'small': [[3, 16, 16, True, 'ReLU', 2], # block0 layer1 os=4
+ [3, 72, 24, False, 'ReLU', 2], # block1 layer2 os=8
+ [3, 88, 24, False, 'ReLU', 1],
+ [5, 96, 40, True, 'HSwish', 2], # block2 layer4 os=16
+ [5, 240, 40, True, 'HSwish', 1],
+ [5, 240, 40, True, 'HSwish', 1],
+ [5, 120, 48, True, 'HSwish', 1], # block3 layer7 os=16
+ [5, 144, 48, True, 'HSwish', 1],
+ [5, 288, 96, True, 'HSwish', 2], # block4 layer9 os=32
+ [5, 576, 96, True, 'HSwish', 1],
+ [5, 576, 96, True, 'HSwish', 1]],
+ 'large': [[3, 16, 16, False, 'ReLU', 1], # block0 layer1 os=2
+ [3, 64, 24, False, 'ReLU', 2], # block1 layer2 os=4
+ [3, 72, 24, False, 'ReLU', 1],
+ [5, 72, 40, True, 'ReLU', 2], # block2 layer4 os=8
+ [5, 120, 40, True, 'ReLU', 1],
+ [5, 120, 40, True, 'ReLU', 1],
+ [3, 240, 80, False, 'HSwish', 2], # block3 layer7 os=16
+ [3, 200, 80, False, 'HSwish', 1],
+ [3, 184, 80, False, 'HSwish', 1],
+ [3, 184, 80, False, 'HSwish', 1],
+ [3, 480, 112, True, 'HSwish', 1], # block4 layer11 os=16
+ [3, 672, 112, True, 'HSwish', 1],
+ [5, 672, 160, True, 'HSwish', 2], # block5 layer13 os=32
+ [5, 960, 160, True, 'HSwish', 1],
+ [5, 960, 160, True, 'HSwish', 1]]
+ } # yapf: disable
+
+ def __init__(self,
+ arch='small',
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ out_indices=(0, 1, 12),
+ frozen_stages=-1,
+ reduction_factor=1,
+ norm_eval=False,
+ with_cp=False):
+ super(MobileNetV3, self).__init__()
+ assert arch in self.arch_settings
+ assert isinstance(reduction_factor, int) and reduction_factor > 0
+ assert mmcv.is_tuple_of(out_indices, int)
+ for index in out_indices:
+ if index not in range(0, len(self.arch_settings[arch]) + 2):
+ raise ValueError(
+ 'the item in out_indices must in '
+ f'range(0, {len(self.arch_settings[arch])+2}). '
+ f'But received {index}')
+
+ if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2):
+ raise ValueError('frozen_stages must be in range(-1, '
+ f'{len(self.arch_settings[arch])+2}). '
+ f'But received {frozen_stages}')
+ self.arch = arch
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.reduction_factor = reduction_factor
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+ self.layers = self._make_layer()
+
+ def _make_layer(self):
+ layers = []
+
+ # build the first layer (layer0)
+ in_channels = 16
+ layer = ConvModule(
+ in_channels=3,
+ out_channels=in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ conv_cfg=dict(type='Conv2dAdaptivePadding'),
+ norm_cfg=self.norm_cfg,
+ act_cfg=dict(type='HSwish'))
+ self.add_module('layer0', layer)
+ layers.append('layer0')
+
+ layer_setting = self.arch_settings[self.arch]
+ for i, params in enumerate(layer_setting):
+ (kernel_size, mid_channels, out_channels, with_se, act,
+ stride) = params
+
+ if self.arch == 'large' and i >= 12 or self.arch == 'small' and \
+ i >= 8:
+ mid_channels = mid_channels // self.reduction_factor
+ out_channels = out_channels // self.reduction_factor
+
+ if with_se:
+ se_cfg = dict(
+ channels=mid_channels,
+ ratio=4,
+ act_cfg=(dict(type='ReLU'),
+ dict(type='HSigmoid', bias=3.0, divisor=6.0)))
+ else:
+ se_cfg = None
+
+ layer = InvertedResidual(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ mid_channels=mid_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ se_cfg=se_cfg,
+ with_expand_conv=(in_channels != mid_channels),
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=dict(type=act),
+ with_cp=self.with_cp)
+ in_channels = out_channels
+ layer_name = 'layer{}'.format(i + 1)
+ self.add_module(layer_name, layer)
+ layers.append(layer_name)
+
+ # build the last layer
+ # block5 layer12 os=32 for small model
+ # block6 layer16 os=32 for large model
+ layer = ConvModule(
+ in_channels=in_channels,
+ out_channels=576 if self.arch == 'small' else 960,
+ kernel_size=1,
+ stride=1,
+ dilation=4,
+ padding=0,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=dict(type='HSwish'))
+ layer_name = 'layer{}'.format(len(layer_setting) + 1)
+ self.add_module(layer_name, layer)
+ layers.append(layer_name)
+
+ # next, convert backbone MobileNetV3 to a semantic segmentation version
+ if self.arch == 'small':
+ self.layer4.depthwise_conv.conv.stride = (1, 1)
+ self.layer9.depthwise_conv.conv.stride = (1, 1)
+ for i in range(4, len(layers)):
+ layer = getattr(self, layers[i])
+ if isinstance(layer, InvertedResidual):
+ modified_module = layer.depthwise_conv.conv
+ else:
+ modified_module = layer.conv
+
+ if i < 9:
+ modified_module.dilation = (2, 2)
+ pad = 2
+ else:
+ modified_module.dilation = (4, 4)
+ pad = 4
+
+ if not isinstance(modified_module, Conv2dAdaptivePadding):
+ # Adjust padding
+ pad *= (modified_module.kernel_size[0] - 1) // 2
+ modified_module.padding = (pad, pad)
+ else:
+ self.layer7.depthwise_conv.conv.stride = (1, 1)
+ self.layer13.depthwise_conv.conv.stride = (1, 1)
+ for i in range(7, len(layers)):
+ layer = getattr(self, layers[i])
+ if isinstance(layer, InvertedResidual):
+ modified_module = layer.depthwise_conv.conv
+ else:
+ modified_module = layer.conv
+
+ if i < 13:
+ modified_module.dilation = (2, 2)
+ pad = 2
+ else:
+ modified_module.dilation = (4, 4)
+ pad = 4
+
+ if not isinstance(modified_module, Conv2dAdaptivePadding):
+ # Adjust padding
+ pad *= (modified_module.kernel_size[0] - 1) // 2
+ modified_module.padding = (pad, pad)
+
+ return layers
+
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+ logger = logging.getLogger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, nn.BatchNorm2d):
+ constant_init(m, 1)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x):
+ outs = []
+ for i, layer_name in enumerate(self.layers):
+ layer = getattr(self, layer_name)
+ x = layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ return outs
+
+ def _freeze_stages(self):
+ for i in range(self.frozen_stages + 1):
+ layer = getattr(self, f'layer{i}')
+ layer.eval()
+ for param in layer.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super(MobileNetV3, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, _BatchNorm):
+ m.eval()
diff --git a/src/custom_mmpkg/custom_mmseg/models/backbones/resnest.py b/src/custom_mmpkg/custom_mmseg/models/backbones/resnest.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ea8fbe3aa2149de6367abf11273ec845a17e013
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/backbones/resnest.py
@@ -0,0 +1,314 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from custom_mmpkg.custom_mmcv.cnn import build_conv_layer, build_norm_layer
+
+from ..builder import BACKBONES
+from ..utils import ResLayer
+from .resnet import Bottleneck as _Bottleneck
+from .resnet import ResNetV1d
+
+
+class RSoftmax(nn.Module):
+ """Radix Softmax module in ``SplitAttentionConv2d``.
+
+ Args:
+ radix (int): Radix of input.
+ groups (int): Groups of input.
+ """
+
+ def __init__(self, radix, groups):
+ super().__init__()
+ self.radix = radix
+ self.groups = groups
+
+ def forward(self, x):
+ batch = x.size(0)
+ if self.radix > 1:
+ x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
+ x = F.softmax(x, dim=1)
+ x = x.reshape(batch, -1)
+ else:
+ x = torch.sigmoid(x)
+ return x
+
+
+class SplitAttentionConv2d(nn.Module):
+ """Split-Attention Conv2d in ResNeSt.
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int | tuple[int]): Same as nn.Conv2d.
+ stride (int | tuple[int]): Same as nn.Conv2d.
+ padding (int | tuple[int]): Same as nn.Conv2d.
+ dilation (int | tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ radix (int): Radix of SpltAtConv2d. Default: 2
+ reduction_factor (int): Reduction factor of inter_channels. Default: 4.
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
+ which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ dcn (dict): Config dict for DCN. Default: None.
+ """
+
+ def __init__(self,
+ in_channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ radix=2,
+ reduction_factor=4,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ dcn=None):
+ super(SplitAttentionConv2d, self).__init__()
+ inter_channels = max(in_channels * radix // reduction_factor, 32)
+ self.radix = radix
+ self.groups = groups
+ self.channels = channels
+ self.with_dcn = dcn is not None
+ self.dcn = dcn
+ fallback_on_stride = False
+ if self.with_dcn:
+ fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
+ if self.with_dcn and not fallback_on_stride:
+ assert conv_cfg is None, 'conv_cfg must be None for DCN'
+ conv_cfg = dcn
+ self.conv = build_conv_layer(
+ conv_cfg,
+ in_channels,
+ channels * radix,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups * radix,
+ bias=False)
+ self.norm0_name, norm0 = build_norm_layer(
+ norm_cfg, channels * radix, postfix=0)
+ self.add_module(self.norm0_name, norm0)
+ self.relu = nn.ReLU(inplace=True)
+ self.fc1 = build_conv_layer(
+ None, channels, inter_channels, 1, groups=self.groups)
+ self.norm1_name, norm1 = build_norm_layer(
+ norm_cfg, inter_channels, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.fc2 = build_conv_layer(
+ None, inter_channels, channels * radix, 1, groups=self.groups)
+ self.rsoftmax = RSoftmax(radix, groups)
+
+ @property
+ def norm0(self):
+ """nn.Module: the normalization layer named "norm0" """
+ return getattr(self, self.norm0_name)
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.norm0(x)
+ x = self.relu(x)
+
+ batch, rchannel = x.shape[:2]
+ batch = x.size(0)
+ if self.radix > 1:
+ splits = x.view(batch, self.radix, -1, *x.shape[2:])
+ gap = splits.sum(dim=1)
+ else:
+ gap = x
+ gap = F.adaptive_avg_pool2d(gap, 1)
+ gap = self.fc1(gap)
+
+ gap = self.norm1(gap)
+ gap = self.relu(gap)
+
+ atten = self.fc2(gap)
+ atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
+
+ if self.radix > 1:
+ attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
+ out = torch.sum(attens * splits, dim=1)
+ else:
+ out = atten * x
+ return out.contiguous()
+
+
+class Bottleneck(_Bottleneck):
+ """Bottleneck block for ResNeSt.
+
+ Args:
+ inplane (int): Input planes of this block.
+ planes (int): Middle planes of this block.
+ groups (int): Groups of conv2.
+ width_per_group (int): Width per group of conv2. 64x4d indicates
+ ``groups=64, width_per_group=4`` and 32x8d indicates
+ ``groups=32, width_per_group=8``.
+ radix (int): Radix of SpltAtConv2d. Default: 2
+ reduction_factor (int): Reduction factor of inter_channels in
+ SplitAttentionConv2d. Default: 4.
+ avg_down_stride (bool): Whether to use average pool for stride in
+ Bottleneck. Default: True.
+ kwargs (dict): Key word arguments for base class.
+ """
+ expansion = 4
+
+ def __init__(self,
+ inplanes,
+ planes,
+ groups=1,
+ base_width=4,
+ base_channels=64,
+ radix=2,
+ reduction_factor=4,
+ avg_down_stride=True,
+ **kwargs):
+ """Bottleneck block for ResNeSt."""
+ super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
+
+ if groups == 1:
+ width = self.planes
+ else:
+ width = math.floor(self.planes *
+ (base_width / base_channels)) * groups
+
+ self.avg_down_stride = avg_down_stride and self.conv2_stride > 1
+
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, width, postfix=1)
+ self.norm3_name, norm3 = build_norm_layer(
+ self.norm_cfg, self.planes * self.expansion, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ self.inplanes,
+ width,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ self.with_modulated_dcn = False
+ self.conv2 = SplitAttentionConv2d(
+ width,
+ width,
+ kernel_size=3,
+ stride=1 if self.avg_down_stride else self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ groups=groups,
+ radix=radix,
+ reduction_factor=reduction_factor,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ dcn=self.dcn)
+ delattr(self, self.norm2_name)
+
+ if self.avg_down_stride:
+ self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)
+
+ self.conv3 = build_conv_layer(
+ self.conv_cfg,
+ width,
+ self.planes * self.expansion,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv1_plugin_names)
+
+ out = self.conv2(out)
+
+ if self.avg_down_stride:
+ out = self.avd_layer(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv2_plugin_names)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv3_plugin_names)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+@BACKBONES.register_module()
+class ResNeSt(ResNetV1d):
+ """ResNeSt backbone.
+
+ Args:
+ groups (int): Number of groups of Bottleneck. Default: 1
+ base_width (int): Base width of Bottleneck. Default: 4
+ radix (int): Radix of SpltAtConv2d. Default: 2
+ reduction_factor (int): Reduction factor of inter_channels in
+ SplitAttentionConv2d. Default: 4.
+ avg_down_stride (bool): Whether to use average pool for stride in
+ Bottleneck. Default: True.
+ kwargs (dict): Keyword arguments for ResNet.
+ """
+
+ arch_settings = {
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3)),
+ 200: (Bottleneck, (3, 24, 36, 3))
+ }
+
+ def __init__(self,
+ groups=1,
+ base_width=4,
+ radix=2,
+ reduction_factor=4,
+ avg_down_stride=True,
+ **kwargs):
+ self.groups = groups
+ self.base_width = base_width
+ self.radix = radix
+ self.reduction_factor = reduction_factor
+ self.avg_down_stride = avg_down_stride
+ super(ResNeSt, self).__init__(**kwargs)
+
+ def make_res_layer(self, **kwargs):
+ """Pack all blocks in a stage into a ``ResLayer``."""
+ return ResLayer(
+ groups=self.groups,
+ base_width=self.base_width,
+ base_channels=self.base_channels,
+ radix=self.radix,
+ reduction_factor=self.reduction_factor,
+ avg_down_stride=self.avg_down_stride,
+ **kwargs)
diff --git a/src/custom_mmpkg/custom_mmseg/models/backbones/resnet.py b/src/custom_mmpkg/custom_mmseg/models/backbones/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..9585254cabbf84fd54cf1644b6bd7c8304f730b8
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/backbones/resnet.py
@@ -0,0 +1,688 @@
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from custom_mmpkg.custom_mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer,
+ constant_init, kaiming_init)
+from custom_mmpkg.custom_mmcv.runner import load_checkpoint
+from custom_mmpkg.custom_mmcv.utils.parrots_wrapper import _BatchNorm
+
+from custom_mmpkg.custom_mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+from ..utils import ResLayer
+
+
+class BasicBlock(nn.Module):
+ """Basic block for ResNet."""
+
+ expansion = 1
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ dcn=None,
+ plugins=None):
+ super(BasicBlock, self).__init__()
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+
+ self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
+
+ self.conv1 = build_conv_layer(
+ conv_cfg,
+ inplanes,
+ planes,
+ 3,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(
+ conv_cfg, planes, planes, 3, padding=1, bias=False)
+ self.add_module(self.norm2_name, norm2)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ self.with_cp = with_cp
+
+ @property
+ def norm1(self):
+ """nn.Module: normalization layer after the first convolution layer"""
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: normalization layer after the second convolution layer"""
+ return getattr(self, self.norm2_name)
+
+ def forward(self, x):
+ """Forward function."""
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ """Bottleneck block for ResNet.
+
+ If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
+ "caffe", the stride-two layer is the first 1x1 conv layer.
+ """
+
+ expansion = 4
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ dcn=None,
+ plugins=None):
+ super(Bottleneck, self).__init__()
+ assert style in ['pytorch', 'caffe']
+ assert dcn is None or isinstance(dcn, dict)
+ assert plugins is None or isinstance(plugins, list)
+ if plugins is not None:
+ allowed_position = ['after_conv1', 'after_conv2', 'after_conv3']
+ assert all(p['position'] in allowed_position for p in plugins)
+
+ self.inplanes = inplanes
+ self.planes = planes
+ self.stride = stride
+ self.dilation = dilation
+ self.style = style
+ self.with_cp = with_cp
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.dcn = dcn
+ self.with_dcn = dcn is not None
+ self.plugins = plugins
+ self.with_plugins = plugins is not None
+
+ if self.with_plugins:
+ # collect plugins for conv1/conv2/conv3
+ self.after_conv1_plugins = [
+ plugin['cfg'] for plugin in plugins
+ if plugin['position'] == 'after_conv1'
+ ]
+ self.after_conv2_plugins = [
+ plugin['cfg'] for plugin in plugins
+ if plugin['position'] == 'after_conv2'
+ ]
+ self.after_conv3_plugins = [
+ plugin['cfg'] for plugin in plugins
+ if plugin['position'] == 'after_conv3'
+ ]
+
+ if self.style == 'pytorch':
+ self.conv1_stride = 1
+ self.conv2_stride = stride
+ else:
+ self.conv1_stride = stride
+ self.conv2_stride = 1
+
+ self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
+ self.norm3_name, norm3 = build_norm_layer(
+ norm_cfg, planes * self.expansion, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ conv_cfg,
+ inplanes,
+ planes,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ fallback_on_stride = False
+ if self.with_dcn:
+ fallback_on_stride = dcn.pop('fallback_on_stride', False)
+ if not self.with_dcn or fallback_on_stride:
+ self.conv2 = build_conv_layer(
+ conv_cfg,
+ planes,
+ planes,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+ else:
+ assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
+ self.conv2 = build_conv_layer(
+ dcn,
+ planes,
+ planes,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.conv3 = build_conv_layer(
+ conv_cfg,
+ planes,
+ planes * self.expansion,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+
+ if self.with_plugins:
+ self.after_conv1_plugin_names = self.make_block_plugins(
+ planes, self.after_conv1_plugins)
+ self.after_conv2_plugin_names = self.make_block_plugins(
+ planes, self.after_conv2_plugins)
+ self.after_conv3_plugin_names = self.make_block_plugins(
+ planes * self.expansion, self.after_conv3_plugins)
+
+ def make_block_plugins(self, in_channels, plugins):
+ """make plugins for block.
+
+ Args:
+ in_channels (int): Input channels of plugin.
+ plugins (list[dict]): List of plugins cfg to build.
+
+ Returns:
+ list[str]: List of the names of plugin.
+ """
+ assert isinstance(plugins, list)
+ plugin_names = []
+ for plugin in plugins:
+ plugin = plugin.copy()
+ name, layer = build_plugin_layer(
+ plugin,
+ in_channels=in_channels,
+ postfix=plugin.pop('postfix', ''))
+ assert not hasattr(self, name), f'duplicate plugin {name}'
+ self.add_module(name, layer)
+ plugin_names.append(name)
+ return plugin_names
+
+ def forward_plugin(self, x, plugin_names):
+ """Forward function for plugins."""
+ out = x
+ for name in plugin_names:
+ out = getattr(self, name)(x)
+ return out
+
+ @property
+ def norm1(self):
+ """nn.Module: normalization layer after the first convolution layer"""
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: normalization layer after the second convolution layer"""
+ return getattr(self, self.norm2_name)
+
+ @property
+ def norm3(self):
+ """nn.Module: normalization layer after the third convolution layer"""
+ return getattr(self, self.norm3_name)
+
+ def forward(self, x):
+ """Forward function."""
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv1_plugin_names)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+ out = self.relu(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv2_plugin_names)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv3_plugin_names)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+@BACKBONES.register_module()
+class ResNet(nn.Module):
+ """ResNet backbone.
+
+ Args:
+ depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
+ in_channels (int): Number of input image channels. Default" 3.
+ stem_channels (int): Number of stem channels. Default: 64.
+ base_channels (int): Number of base channels of res layer. Default: 64.
+ num_stages (int): Resnet stages, normally 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+ plugins (list[dict]): List of plugins for stages, each dict contains:
+
+ - cfg (dict, required): Cfg dict to build plugin.
+
+ - position (str, required): Position inside block to insert plugin,
+ options: 'after_conv1', 'after_conv2', 'after_conv3'.
+
+ - stages (tuple[bool], optional): Stages to apply plugin, length
+ should be same as 'num_stages'
+ multi_grid (Sequence[int]|None): Multi grid dilation rates of last
+ stage. Default: None
+ contract_dilation (bool): Whether contract first dilation of each layer
+ Default: False
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ zero_init_residual (bool): Whether to use zero init for last norm layer
+ in resblocks to let them behave as identity.
+
+ Example:
+ >>> from custom_mmpkg.custom_mmseg.models import ResNet
+ >>> import torch
+ >>> self = ResNet(depth=18)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 64, 8, 8)
+ (1, 128, 4, 4)
+ (1, 256, 2, 2)
+ (1, 512, 1, 1)
+ """
+
+ arch_settings = {
+ 18: (BasicBlock, (2, 2, 2, 2)),
+ 34: (BasicBlock, (3, 4, 6, 3)),
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3))
+ }
+
+ def __init__(self,
+ depth,
+ in_channels=3,
+ stem_channels=64,
+ base_channels=64,
+ num_stages=4,
+ strides=(1, 2, 2, 2),
+ dilations=(1, 1, 1, 1),
+ out_indices=(0, 1, 2, 3),
+ style='pytorch',
+ deep_stem=False,
+ avg_down=False,
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=False,
+ dcn=None,
+ stage_with_dcn=(False, False, False, False),
+ plugins=None,
+ multi_grid=None,
+ contract_dilation=False,
+ with_cp=False,
+ zero_init_residual=True):
+ super(ResNet, self).__init__()
+ if depth not in self.arch_settings:
+ raise KeyError(f'invalid depth {depth} for resnet')
+ self.depth = depth
+ self.stem_channels = stem_channels
+ self.base_channels = base_channels
+ self.num_stages = num_stages
+ assert num_stages >= 1 and num_stages <= 4
+ self.strides = strides
+ self.dilations = dilations
+ assert len(strides) == len(dilations) == num_stages
+ self.out_indices = out_indices
+ assert max(out_indices) < num_stages
+ self.style = style
+ self.deep_stem = deep_stem
+ self.avg_down = avg_down
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.with_cp = with_cp
+ self.norm_eval = norm_eval
+ self.dcn = dcn
+ self.stage_with_dcn = stage_with_dcn
+ if dcn is not None:
+ assert len(stage_with_dcn) == num_stages
+ self.plugins = plugins
+ self.multi_grid = multi_grid
+ self.contract_dilation = contract_dilation
+ self.zero_init_residual = zero_init_residual
+ self.block, stage_blocks = self.arch_settings[depth]
+ self.stage_blocks = stage_blocks[:num_stages]
+ self.inplanes = stem_channels
+
+ self._make_stem_layer(in_channels, stem_channels)
+
+ self.res_layers = []
+ for i, num_blocks in enumerate(self.stage_blocks):
+ stride = strides[i]
+ dilation = dilations[i]
+ dcn = self.dcn if self.stage_with_dcn[i] else None
+ if plugins is not None:
+ stage_plugins = self.make_stage_plugins(plugins, i)
+ else:
+ stage_plugins = None
+ # multi grid is applied to last layer only
+ stage_multi_grid = multi_grid if i == len(
+ self.stage_blocks) - 1 else None
+ planes = base_channels * 2**i
+ res_layer = self.make_res_layer(
+ block=self.block,
+ inplanes=self.inplanes,
+ planes=planes,
+ num_blocks=num_blocks,
+ stride=stride,
+ dilation=dilation,
+ style=self.style,
+ avg_down=self.avg_down,
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ dcn=dcn,
+ plugins=stage_plugins,
+ multi_grid=stage_multi_grid,
+ contract_dilation=contract_dilation)
+ self.inplanes = planes * self.block.expansion
+ layer_name = f'layer{i+1}'
+ self.add_module(layer_name, res_layer)
+ self.res_layers.append(layer_name)
+
+ self._freeze_stages()
+
+ self.feat_dim = self.block.expansion * base_channels * 2**(
+ len(self.stage_blocks) - 1)
+
+ def make_stage_plugins(self, plugins, stage_idx):
+ """make plugins for ResNet 'stage_idx'th stage .
+
+ Currently we support to insert 'context_block',
+ 'empirical_attention_block', 'nonlocal_block' into the backbone like
+ ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of
+ Bottleneck.
+
+ An example of plugins format could be :
+ >>> plugins=[
+ ... dict(cfg=dict(type='xxx', arg1='xxx'),
+ ... stages=(False, True, True, True),
+ ... position='after_conv2'),
+ ... dict(cfg=dict(type='yyy'),
+ ... stages=(True, True, True, True),
+ ... position='after_conv3'),
+ ... dict(cfg=dict(type='zzz', postfix='1'),
+ ... stages=(True, True, True, True),
+ ... position='after_conv3'),
+ ... dict(cfg=dict(type='zzz', postfix='2'),
+ ... stages=(True, True, True, True),
+ ... position='after_conv3')
+ ... ]
+ >>> self = ResNet(depth=18)
+ >>> stage_plugins = self.make_stage_plugins(plugins, 0)
+ >>> assert len(stage_plugins) == 3
+
+ Suppose 'stage_idx=0', the structure of blocks in the stage would be:
+ conv1-> conv2->conv3->yyy->zzz1->zzz2
+ Suppose 'stage_idx=1', the structure of blocks in the stage would be:
+ conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2
+
+ If stages is missing, the plugin would be applied to all stages.
+
+ Args:
+ plugins (list[dict]): List of plugins cfg to build. The postfix is
+ required if multiple same type plugins are inserted.
+ stage_idx (int): Index of stage to build
+
+ Returns:
+ list[dict]: Plugins for current stage
+ """
+ stage_plugins = []
+ for plugin in plugins:
+ plugin = plugin.copy()
+ stages = plugin.pop('stages', None)
+ assert stages is None or len(stages) == self.num_stages
+ # whether to insert plugin into current stage
+ if stages is None or stages[stage_idx]:
+ stage_plugins.append(plugin)
+
+ return stage_plugins
+
+ def make_res_layer(self, **kwargs):
+ """Pack all blocks in a stage into a ``ResLayer``."""
+ return ResLayer(**kwargs)
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+ def _make_stem_layer(self, in_channels, stem_channels):
+ """Make stem layer for ResNet."""
+ if self.deep_stem:
+ self.stem = nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ stem_channels // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
+ nn.ReLU(inplace=True),
+ build_conv_layer(
+ self.conv_cfg,
+ stem_channels // 2,
+ stem_channels // 2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
+ nn.ReLU(inplace=True),
+ build_conv_layer(
+ self.conv_cfg,
+ stem_channels // 2,
+ stem_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg, stem_channels)[1],
+ nn.ReLU(inplace=True))
+ else:
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ stem_channels,
+ kernel_size=7,
+ stride=2,
+ padding=3,
+ bias=False)
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, stem_channels, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ def _freeze_stages(self):
+ """Freeze stages param and norm stats."""
+ if self.frozen_stages >= 0:
+ if self.deep_stem:
+ self.stem.eval()
+ for param in self.stem.parameters():
+ param.requires_grad = False
+ else:
+ self.norm1.eval()
+ for m in [self.conv1, self.norm1]:
+ for param in m.parameters():
+ param.requires_grad = False
+
+ for i in range(1, self.frozen_stages + 1):
+ m = getattr(self, f'layer{i}')
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+
+ if self.dcn is not None:
+ for m in self.modules():
+ if isinstance(m, Bottleneck) and hasattr(
+ m, 'conv2_offset'):
+ constant_init(m.conv2_offset, 0)
+
+ if self.zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ constant_init(m.norm3, 0)
+ elif isinstance(m, BasicBlock):
+ constant_init(m.norm2, 0)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x):
+ """Forward function."""
+ if self.deep_stem:
+ x = self.stem(x)
+ else:
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ outs = []
+ for i, layer_name in enumerate(self.res_layers):
+ res_layer = getattr(self, layer_name)
+ x = res_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ return tuple(outs)
+
+ def train(self, mode=True):
+ """Convert the model into training mode while keep normalization layer
+ freezed."""
+ super(ResNet, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ # trick: eval have effect on BatchNorm only
+ if isinstance(m, _BatchNorm):
+ m.eval()
+
+
+@BACKBONES.register_module()
+class ResNetV1c(ResNet):
+ """ResNetV1c variant described in [1]_.
+
+ Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv
+ in the input stem with three 3x3 convs.
+
+ References:
+ .. [1] https://arxiv.org/pdf/1812.01187.pdf
+ """
+
+ def __init__(self, **kwargs):
+ super(ResNetV1c, self).__init__(
+ deep_stem=True, avg_down=False, **kwargs)
+
+
+@BACKBONES.register_module()
+class ResNetV1d(ResNet):
+ """ResNetV1d variant described in [1]_.
+
+ Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in
+ the input stem with three 3x3 convs. And in the downsampling block, a 2x2
+ avg_pool with stride 2 is added before conv, whose stride is changed to 1.
+ """
+
+ def __init__(self, **kwargs):
+ super(ResNetV1d, self).__init__(
+ deep_stem=True, avg_down=True, **kwargs)
diff --git a/src/custom_mmpkg/custom_mmseg/models/backbones/resnext.py b/src/custom_mmpkg/custom_mmseg/models/backbones/resnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6a2910074e1671c2e7db2fd3e86f995c590d18b
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/backbones/resnext.py
@@ -0,0 +1,145 @@
+import math
+
+from custom_mmpkg.custom_mmcv.cnn import build_conv_layer, build_norm_layer
+
+from ..builder import BACKBONES
+from ..utils import ResLayer
+from .resnet import Bottleneck as _Bottleneck
+from .resnet import ResNet
+
+
+class Bottleneck(_Bottleneck):
+ """Bottleneck block for ResNeXt.
+
+ If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
+ "caffe", the stride-two layer is the first 1x1 conv layer.
+ """
+
+ def __init__(self,
+ inplanes,
+ planes,
+ groups=1,
+ base_width=4,
+ base_channels=64,
+ **kwargs):
+ super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
+
+ if groups == 1:
+ width = self.planes
+ else:
+ width = math.floor(self.planes *
+ (base_width / base_channels)) * groups
+
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, width, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(
+ self.norm_cfg, width, postfix=2)
+ self.norm3_name, norm3 = build_norm_layer(
+ self.norm_cfg, self.planes * self.expansion, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ self.inplanes,
+ width,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ fallback_on_stride = False
+ self.with_modulated_dcn = False
+ if self.with_dcn:
+ fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
+ if not self.with_dcn or fallback_on_stride:
+ self.conv2 = build_conv_layer(
+ self.conv_cfg,
+ width,
+ width,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ groups=groups,
+ bias=False)
+ else:
+ assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
+ self.conv2 = build_conv_layer(
+ self.dcn,
+ width,
+ width,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ groups=groups,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.conv3 = build_conv_layer(
+ self.conv_cfg,
+ width,
+ self.planes * self.expansion,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+
+@BACKBONES.register_module()
+class ResNeXt(ResNet):
+ """ResNeXt backbone.
+
+ Args:
+ depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
+ in_channels (int): Number of input image channels. Normally 3.
+ num_stages (int): Resnet stages, normally 4.
+ groups (int): Group of resnext.
+ base_width (int): Base width of resnext.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+ not freezing any parameters.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ zero_init_residual (bool): whether to use zero init for last norm layer
+ in resblocks to let them behave as identity.
+
+ Example:
+ >>> from custom_mmpkg.custom_mmseg.models import ResNeXt
+ >>> import torch
+ >>> self = ResNeXt(depth=50)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 256, 8, 8)
+ (1, 512, 4, 4)
+ (1, 1024, 2, 2)
+ (1, 2048, 1, 1)
+ """
+
+ arch_settings = {
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3))
+ }
+
+ def __init__(self, groups=1, base_width=4, **kwargs):
+ self.groups = groups
+ self.base_width = base_width
+ super(ResNeXt, self).__init__(**kwargs)
+
+ def make_res_layer(self, **kwargs):
+ """Pack all blocks in a stage into a ``ResLayer``"""
+ return ResLayer(
+ groups=self.groups,
+ base_width=self.base_width,
+ base_channels=self.base_channels,
+ **kwargs)
diff --git a/src/custom_mmpkg/custom_mmseg/models/backbones/unet.py b/src/custom_mmpkg/custom_mmseg/models/backbones/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..694272114506e42ebc2531996432a567e1e588b6
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/backbones/unet.py
@@ -0,0 +1,429 @@
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from custom_mmpkg.custom_mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
+ build_norm_layer, constant_init, kaiming_init)
+from custom_mmpkg.custom_mmcv.runner import load_checkpoint
+from custom_mmpkg.custom_mmcv.utils.parrots_wrapper import _BatchNorm
+
+from custom_mmpkg.custom_mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+from ..utils import UpConvBlock
+
+
+class BasicConvBlock(nn.Module):
+ """Basic convolutional block for UNet.
+
+ This module consists of several plain convolutional layers.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ num_convs (int): Number of convolutional layers. Default: 2.
+ stride (int): Whether use stride convolution to downsample
+ the input feature map. If stride=2, it only uses stride convolution
+ in the first convolutional layer to downsample the input feature
+ map. Options are 1 or 2. Default: 1.
+ dilation (int): Whether use dilated convolution to expand the
+ receptive field. Set dilation rate of each convolutional layer and
+ the dilation rate of the first convolutional layer is always 1.
+ Default: 1.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ conv_cfg (dict | None): Config dict for convolution layer.
+ Default: None.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ dcn (bool): Use deformable convolution in convolutional layer or not.
+ Default: None.
+ plugins (dict): plugins for convolutional layers. Default: None.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_convs=2,
+ stride=1,
+ dilation=1,
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ dcn=None,
+ plugins=None):
+ super(BasicConvBlock, self).__init__()
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+
+ self.with_cp = with_cp
+ convs = []
+ for i in range(num_convs):
+ convs.append(
+ ConvModule(
+ in_channels=in_channels if i == 0 else out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=stride if i == 0 else 1,
+ dilation=1 if i == 0 else dilation,
+ padding=1 if i == 0 else dilation,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+
+ self.convs = nn.Sequential(*convs)
+
+ def forward(self, x):
+ """Forward function."""
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(self.convs, x)
+ else:
+ out = self.convs(x)
+ return out
+
+
+@UPSAMPLE_LAYERS.register_module()
+class DeconvModule(nn.Module):
+ """Deconvolution upsample module in decoder for UNet (2X upsample).
+
+ This module uses deconvolution to upsample feature map in the decoder
+ of UNet.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ kernel_size (int): Kernel size of the convolutional layer. Default: 4.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ with_cp=False,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ *,
+ kernel_size=4,
+ scale_factor=2):
+ super(DeconvModule, self).__init__()
+
+ assert (kernel_size - scale_factor >= 0) and\
+ (kernel_size - scale_factor) % 2 == 0,\
+ f'kernel_size should be greater than or equal to scale_factor '\
+ f'and (kernel_size - scale_factor) should be even numbers, '\
+ f'while the kernel size is {kernel_size} and scale_factor is '\
+ f'{scale_factor}.'
+
+ stride = scale_factor
+ padding = (kernel_size - scale_factor) // 2
+ self.with_cp = with_cp
+ deconv = nn.ConvTranspose2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding)
+
+ norm_name, norm = build_norm_layer(norm_cfg, out_channels)
+ activate = build_activation_layer(act_cfg)
+ self.deconv_upsamping = nn.Sequential(deconv, norm, activate)
+
+ def forward(self, x):
+ """Forward function."""
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(self.deconv_upsamping, x)
+ else:
+ out = self.deconv_upsamping(x)
+ return out
+
+
+@UPSAMPLE_LAYERS.register_module()
+class InterpConv(nn.Module):
+ """Interpolation upsample module in decoder for UNet.
+
+ This module uses interpolation to upsample feature map in the decoder
+ of UNet. It consists of one interpolation upsample layer and one
+ convolutional layer. It can be one interpolation upsample layer followed
+ by one convolutional layer (conv_first=False) or one convolutional layer
+ followed by one interpolation upsample layer (conv_first=True).
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ conv_cfg (dict | None): Config dict for convolution layer.
+ Default: None.
+ conv_first (bool): Whether convolutional layer or interpolation
+ upsample layer first. Default: False. It means interpolation
+ upsample layer followed by one convolutional layer.
+ kernel_size (int): Kernel size of the convolutional layer. Default: 1.
+ stride (int): Stride of the convolutional layer. Default: 1.
+ padding (int): Padding of the convolutional layer. Default: 1.
+ upsample_cfg (dict): Interpolation config of the upsample layer.
+ Default: dict(
+ scale_factor=2, mode='bilinear', align_corners=False).
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ with_cp=False,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ *,
+ conv_cfg=None,
+ conv_first=False,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ upsample_cfg=dict(
+ scale_factor=2, mode='bilinear', align_corners=False)):
+ super(InterpConv, self).__init__()
+
+ self.with_cp = with_cp
+ conv = ConvModule(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ upsample = nn.Upsample(**upsample_cfg)
+ if conv_first:
+ self.interp_upsample = nn.Sequential(conv, upsample)
+ else:
+ self.interp_upsample = nn.Sequential(upsample, conv)
+
+ def forward(self, x):
+ """Forward function."""
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(self.interp_upsample, x)
+ else:
+ out = self.interp_upsample(x)
+ return out
+
+
+@BACKBONES.register_module()
+class UNet(nn.Module):
+ """UNet backbone.
+ U-Net: Convolutional Networks for Biomedical Image Segmentation.
+ https://arxiv.org/pdf/1505.04597.pdf
+
+ Args:
+ in_channels (int): Number of input image channels. Default" 3.
+ base_channels (int): Number of base channels of each stage.
+ The output channels of the first stage. Default: 64.
+ num_stages (int): Number of stages in encoder, normally 5. Default: 5.
+ strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
+ len(strides) is equal to num_stages. Normally the stride of the
+ first stage in encoder is 1. If strides[i]=2, it uses stride
+ convolution to downsample in the correspondence encoder stage.
+ Default: (1, 1, 1, 1, 1).
+ enc_num_convs (Sequence[int]): Number of convolutional layers in the
+ convolution block of the correspondence encoder stage.
+ Default: (2, 2, 2, 2, 2).
+ dec_num_convs (Sequence[int]): Number of convolutional layers in the
+ convolution block of the correspondence decoder stage.
+ Default: (2, 2, 2, 2).
+ downsamples (Sequence[int]): Whether use MaxPool to downsample the
+ feature map after the first stage of encoder
+ (stages: [1, num_stages)). If the correspondence encoder stage use
+ stride convolution (strides[i]=2), it will never use MaxPool to
+ downsample, even downsamples[i-1]=True.
+ Default: (True, True, True, True).
+ enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
+ Default: (1, 1, 1, 1, 1).
+ dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
+ Default: (1, 1, 1, 1).
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ conv_cfg (dict | None): Config dict for convolution layer.
+ Default: None.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ upsample_cfg (dict): The upsample config of the upsample module in
+ decoder. Default: dict(type='InterpConv').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ dcn (bool): Use deformable convolution in convolutional layer or not.
+ Default: None.
+ plugins (dict): plugins for convolutional layers. Default: None.
+
+ Notice:
+ The input image size should be divisible by the whole downsample rate
+ of the encoder. More detail of the whole downsample rate can be found
+ in UNet._check_input_divisible.
+
+ """
+
+ def __init__(self,
+ in_channels=3,
+ base_channels=64,
+ num_stages=5,
+ strides=(1, 1, 1, 1, 1),
+ enc_num_convs=(2, 2, 2, 2, 2),
+ dec_num_convs=(2, 2, 2, 2),
+ downsamples=(True, True, True, True),
+ enc_dilations=(1, 1, 1, 1, 1),
+ dec_dilations=(1, 1, 1, 1),
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ upsample_cfg=dict(type='InterpConv'),
+ norm_eval=False,
+ dcn=None,
+ plugins=None):
+ super(UNet, self).__init__()
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+ assert len(strides) == num_stages, \
+ 'The length of strides should be equal to num_stages, '\
+ f'while the strides is {strides}, the length of '\
+ f'strides is {len(strides)}, and the num_stages is '\
+ f'{num_stages}.'
+ assert len(enc_num_convs) == num_stages, \
+ 'The length of enc_num_convs should be equal to num_stages, '\
+ f'while the enc_num_convs is {enc_num_convs}, the length of '\
+ f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
+ f'{num_stages}.'
+ assert len(dec_num_convs) == (num_stages-1), \
+ 'The length of dec_num_convs should be equal to (num_stages-1), '\
+ f'while the dec_num_convs is {dec_num_convs}, the length of '\
+ f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
+ f'{num_stages}.'
+ assert len(downsamples) == (num_stages-1), \
+ 'The length of downsamples should be equal to (num_stages-1), '\
+ f'while the downsamples is {downsamples}, the length of '\
+ f'downsamples is {len(downsamples)}, and the num_stages is '\
+ f'{num_stages}.'
+ assert len(enc_dilations) == num_stages, \
+ 'The length of enc_dilations should be equal to num_stages, '\
+ f'while the enc_dilations is {enc_dilations}, the length of '\
+ f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
+ f'{num_stages}.'
+ assert len(dec_dilations) == (num_stages-1), \
+ 'The length of dec_dilations should be equal to (num_stages-1), '\
+ f'while the dec_dilations is {dec_dilations}, the length of '\
+ f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
+ f'{num_stages}.'
+ self.num_stages = num_stages
+ self.strides = strides
+ self.downsamples = downsamples
+ self.norm_eval = norm_eval
+ self.base_channels = base_channels
+
+ self.encoder = nn.ModuleList()
+ self.decoder = nn.ModuleList()
+
+ for i in range(num_stages):
+ enc_conv_block = []
+ if i != 0:
+ if strides[i] == 1 and downsamples[i - 1]:
+ enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
+ upsample = (strides[i] != 1 or downsamples[i - 1])
+ self.decoder.append(
+ UpConvBlock(
+ conv_block=BasicConvBlock,
+ in_channels=base_channels * 2**i,
+ skip_channels=base_channels * 2**(i - 1),
+ out_channels=base_channels * 2**(i - 1),
+ num_convs=dec_num_convs[i - 1],
+ stride=1,
+ dilation=dec_dilations[i - 1],
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ upsample_cfg=upsample_cfg if upsample else None,
+ dcn=None,
+ plugins=None))
+
+ enc_conv_block.append(
+ BasicConvBlock(
+ in_channels=in_channels,
+ out_channels=base_channels * 2**i,
+ num_convs=enc_num_convs[i],
+ stride=strides[i],
+ dilation=enc_dilations[i],
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ dcn=None,
+ plugins=None))
+ self.encoder.append((nn.Sequential(*enc_conv_block)))
+ in_channels = base_channels * 2**i
+
+ def forward(self, x):
+ self._check_input_divisible(x)
+ enc_outs = []
+ for enc in self.encoder:
+ x = enc(x)
+ enc_outs.append(x)
+ dec_outs = [x]
+ for i in reversed(range(len(self.decoder))):
+ x = self.decoder[i](enc_outs[i], x)
+ dec_outs.append(x)
+
+ return dec_outs
+
+ def train(self, mode=True):
+ """Convert the model into training mode while keep normalization layer
+ freezed."""
+ super(UNet, self).train(mode)
+ if mode and self.norm_eval:
+ for m in self.modules():
+ # trick: eval have effect on BatchNorm only
+ if isinstance(m, _BatchNorm):
+ m.eval()
+
+ def _check_input_divisible(self, x):
+ h, w = x.shape[-2:]
+ whole_downsample_rate = 1
+ for i in range(1, self.num_stages):
+ if self.strides[i] == 2 or self.downsamples[i - 1]:
+ whole_downsample_rate *= 2
+ assert (h % whole_downsample_rate == 0) \
+ and (w % whole_downsample_rate == 0),\
+ f'The input image size {(h, w)} should be divisible by the whole '\
+ f'downsample rate {whole_downsample_rate}, when num_stages is '\
+ f'{self.num_stages}, strides is {self.strides}, and downsamples '\
+ f'is {self.downsamples}.'
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+ else:
+ raise TypeError('pretrained must be a str or None')
diff --git a/src/custom_mmpkg/custom_mmseg/models/backbones/vit.py b/src/custom_mmpkg/custom_mmseg/models/backbones/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..93bae38424b69dd6699089163db30fa787efb9ac
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/backbones/vit.py
@@ -0,0 +1,459 @@
+"""Modified from https://github.com/rwightman/pytorch-image-
+models/blob/master/timm/models/vision_transformer.py."""
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from custom_mmpkg.custom_mmcv.cnn import (Conv2d, Linear, build_activation_layer, build_norm_layer,
+ constant_init, kaiming_init, normal_init)
+from custom_mmpkg.custom_mmcv.runner import _load_checkpoint
+from custom_mmpkg.custom_mmcv.utils.parrots_wrapper import _BatchNorm
+
+from custom_mmpkg.custom_mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+from ..utils import DropPath, trunc_normal_
+
+
+class Mlp(nn.Module):
+ """MLP layer for Encoder block.
+
+ Args:
+ in_features(int): Input dimension for the first fully
+ connected layer.
+ hidden_features(int): Output dimension for the first fully
+ connected layer.
+ out_features(int): Output dementsion for the second fully
+ connected layer.
+ act_cfg(dict): Config dict for activation layer.
+ Default: dict(type='GELU').
+ drop(float): Drop rate for the dropout layer. Dropout rate has
+ to be between 0 and 1. Default: 0.
+ """
+
+ def __init__(self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_cfg=dict(type='GELU'),
+ drop=0.):
+ super(Mlp, self).__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = Linear(in_features, hidden_features)
+ self.act = build_activation_layer(act_cfg)
+ self.fc2 = Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+ """Attention layer for Encoder block.
+
+ Args:
+ dim (int): Dimension for the input vector.
+ num_heads (int): Number of parallel attention heads.
+ qkv_bias (bool): Enable bias for qkv if True. Default: False.
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+ attn_drop (float): Drop rate for attention output weights.
+ Default: 0.
+ proj_drop (float): Drop rate for output weights. Default: 0.
+ """
+
+ def __init__(self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.,
+ proj_drop=0.):
+ super(Attention, self).__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ b, n, c = x.shape
+ qkv = self.qkv(x).reshape(b, n, 3, self.num_heads,
+ c // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2]
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(b, n, c)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+ """Implements encoder block with residual connection.
+
+ Args:
+ dim (int): The feature dimension.
+ num_heads (int): Number of parallel attention heads.
+ mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float): Drop rate for mlp output weights. Default: 0.
+ attn_drop (float): Drop rate for attention output weights.
+ Default: 0.
+ proj_drop (float): Drop rate for attn layer output weights.
+ Default: 0.
+ drop_path (float): Drop rate for paths of model.
+ Default: 0.
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='GELU').
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='LN', requires_grad=True).
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ """
+
+ def __init__(self,
+ dim,
+ num_heads,
+ mlp_ratio=4,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ proj_drop=0.,
+ drop_path=0.,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='LN', eps=1e-6),
+ with_cp=False):
+ super(Block, self).__init__()
+ self.with_cp = with_cp
+ _, self.norm1 = build_norm_layer(norm_cfg, dim)
+ self.attn = Attention(dim, num_heads, qkv_bias, qk_scale, attn_drop,
+ proj_drop)
+ self.drop_path = DropPath(
+ drop_path) if drop_path > 0. else nn.Identity()
+ _, self.norm2 = build_norm_layer(norm_cfg, dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_cfg=act_cfg,
+ drop=drop)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ out = x + self.drop_path(self.attn(self.norm1(x)))
+ out = out + self.drop_path(self.mlp(self.norm2(out)))
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ return out
+
+
+class PatchEmbed(nn.Module):
+ """Image to Patch Embedding.
+
+ Args:
+ img_size (int | tuple): Input image size.
+ default: 224.
+ patch_size (int): Width and height for a patch.
+ default: 16.
+ in_channels (int): Input channels for images. Default: 3.
+ embed_dim (int): The embedding dimension. Default: 768.
+ """
+
+ def __init__(self,
+ img_size=224,
+ patch_size=16,
+ in_channels=3,
+ embed_dim=768):
+ super(PatchEmbed, self).__init__()
+ if isinstance(img_size, int):
+ self.img_size = (img_size, img_size)
+ elif isinstance(img_size, tuple):
+ self.img_size = img_size
+ else:
+ raise TypeError('img_size must be type of int or tuple')
+ h, w = self.img_size
+ self.patch_size = (patch_size, patch_size)
+ self.num_patches = (h // patch_size) * (w // patch_size)
+ self.proj = Conv2d(
+ in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, x):
+ return self.proj(x).flatten(2).transpose(1, 2)
+
+
+@BACKBONES.register_module()
+class VisionTransformer(nn.Module):
+ """Vision transformer backbone.
+
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for
+ Image Recognition at Scale` - https://arxiv.org/abs/2010.11929
+
+ Args:
+ img_size (tuple): input image size. Default: (224, 224).
+ patch_size (int, tuple): patch size. Default: 16.
+ in_channels (int): number of input channels. Default: 3.
+ embed_dim (int): embedding dimension. Default: 768.
+ depth (int): depth of transformer. Default: 12.
+ num_heads (int): number of attention heads. Default: 12.
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
+ Default: 4.
+ out_indices (list | tuple | int): Output from which stages.
+ Default: -1.
+ qkv_bias (bool): enable bias for qkv if True. Default: True.
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set.
+ drop_rate (float): dropout rate. Default: 0.
+ attn_drop_rate (float): attention dropout rate. Default: 0.
+ drop_path_rate (float): Rate of DropPath. Default: 0.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='LN', eps=1e-6, requires_grad=True).
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='GELU').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ final_norm (bool): Whether to add a additional layer to normalize
+ final feature map. Default: False.
+ interpolate_mode (str): Select the interpolate mode for position
+ embeding vector resize. Default: bicubic.
+ with_cls_token (bool): If concatenating class token into image tokens
+ as transformer input. Default: True.
+ with_cp (bool): Use checkpoint or not. Using checkpoint
+ will save some memory while slowing down the training speed.
+ Default: False.
+ """
+
+ def __init__(self,
+ img_size=(224, 224),
+ patch_size=16,
+ in_channels=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ out_indices=11,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ norm_cfg=dict(type='LN', eps=1e-6, requires_grad=True),
+ act_cfg=dict(type='GELU'),
+ norm_eval=False,
+ final_norm=False,
+ with_cls_token=True,
+ interpolate_mode='bicubic',
+ with_cp=False):
+ super(VisionTransformer, self).__init__()
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.features = self.embed_dim = embed_dim
+ self.patch_embed = PatchEmbed(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=embed_dim)
+
+ self.with_cls_token = with_cls_token
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ if isinstance(out_indices, int):
+ self.out_indices = [out_indices]
+ elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
+ self.out_indices = out_indices
+ else:
+ raise TypeError('out_indices must be type of int, list or tuple')
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
+ ] # stochastic depth decay rule
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=dpr[i],
+ attn_drop=attn_drop_rate,
+ act_cfg=act_cfg,
+ norm_cfg=norm_cfg,
+ with_cp=with_cp) for i in range(depth)
+ ])
+
+ self.interpolate_mode = interpolate_mode
+ self.final_norm = final_norm
+ if final_norm:
+ _, self.norm = build_norm_layer(norm_cfg, embed_dim)
+
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ checkpoint = _load_checkpoint(pretrained, logger=logger)
+ if 'state_dict' in checkpoint:
+ state_dict = checkpoint['state_dict']
+ else:
+ state_dict = checkpoint
+
+ if 'pos_embed' in state_dict.keys():
+ if self.pos_embed.shape != state_dict['pos_embed'].shape:
+ logger.info(msg=f'Resize the pos_embed shape from \
+{state_dict["pos_embed"].shape} to {self.pos_embed.shape}')
+ h, w = self.img_size
+ pos_size = int(
+ math.sqrt(state_dict['pos_embed'].shape[1] - 1))
+ state_dict['pos_embed'] = self.resize_pos_embed(
+ state_dict['pos_embed'], (h, w), (pos_size, pos_size),
+ self.patch_size, self.interpolate_mode)
+
+ self.load_state_dict(state_dict, False)
+
+ elif pretrained is None:
+ # We only implement the 'jax_impl' initialization implemented at
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
+ trunc_normal_(self.pos_embed, std=.02)
+ trunc_normal_(self.cls_token, std=.02)
+ for n, m in self.named_modules():
+ if isinstance(m, Linear):
+ trunc_normal_(m.weight, std=.02)
+ if m.bias is not None:
+ if 'mlp' in n:
+ normal_init(m.bias, std=1e-6)
+ else:
+ constant_init(m.bias, 0)
+ elif isinstance(m, Conv2d):
+ kaiming_init(m.weight, mode='fan_in')
+ if m.bias is not None:
+ constant_init(m.bias, 0)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
+ constant_init(m.bias, 0)
+ constant_init(m.weight, 1.0)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def _pos_embeding(self, img, patched_img, pos_embed):
+ """Positiong embeding method.
+
+ Resize the pos_embed, if the input image size doesn't match
+ the training size.
+ Args:
+ img (torch.Tensor): The inference image tensor, the shape
+ must be [B, C, H, W].
+ patched_img (torch.Tensor): The patched image, it should be
+ shape of [B, L1, C].
+ pos_embed (torch.Tensor): The pos_embed weighs, it should be
+ shape of [B, L2, c].
+ Return:
+ torch.Tensor: The pos encoded image feature.
+ """
+ assert patched_img.ndim == 3 and pos_embed.ndim == 3, \
+ 'the shapes of patched_img and pos_embed must be [B, L, C]'
+ x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
+ if x_len != pos_len:
+ if pos_len == (self.img_size[0] // self.patch_size) * (
+ self.img_size[1] // self.patch_size) + 1:
+ pos_h = self.img_size[0] // self.patch_size
+ pos_w = self.img_size[1] // self.patch_size
+ else:
+ raise ValueError(
+ 'Unexpected shape of pos_embed, got {}.'.format(
+ pos_embed.shape))
+ pos_embed = self.resize_pos_embed(pos_embed, img.shape[2:],
+ (pos_h, pos_w), self.patch_size,
+ self.interpolate_mode)
+ return self.pos_drop(patched_img + pos_embed)
+
+ @staticmethod
+ def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
+ """Resize pos_embed weights.
+
+ Resize pos_embed using bicubic interpolate method.
+ Args:
+ pos_embed (torch.Tensor): pos_embed weights.
+ input_shpae (tuple): Tuple for (input_h, intput_w).
+ pos_shape (tuple): Tuple for (pos_h, pos_w).
+ patch_size (int): Patch size.
+ Return:
+ torch.Tensor: The resized pos_embed of shape [B, L_new, C]
+ """
+ assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
+ input_h, input_w = input_shpae
+ pos_h, pos_w = pos_shape
+ cls_token_weight = pos_embed[:, 0]
+ pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
+ pos_embed_weight = pos_embed_weight.reshape(
+ 1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
+ pos_embed_weight = F.interpolate(
+ pos_embed_weight,
+ size=[input_h // patch_size, input_w // patch_size],
+ align_corners=False,
+ mode=mode)
+ cls_token_weight = cls_token_weight.unsqueeze(1)
+ pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
+ pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
+ return pos_embed
+
+ def forward(self, inputs):
+ B = inputs.shape[0]
+
+ x = self.patch_embed(inputs)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+ x = self._pos_embeding(inputs, x, self.pos_embed)
+
+ if not self.with_cls_token:
+ # Remove class token for transformer input
+ x = x[:, 1:]
+
+ outs = []
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i == len(self.blocks) - 1:
+ if self.final_norm:
+ x = self.norm(x)
+ if i in self.out_indices:
+ if self.with_cls_token:
+ # Remove class token and reshape token for decoder head
+ out = x[:, 1:]
+ else:
+ out = x
+ B, _, C = out.shape
+ out = out.reshape(B, inputs.shape[2] // self.patch_size,
+ inputs.shape[3] // self.patch_size,
+ C).permute(0, 3, 1, 2)
+ outs.append(out)
+
+ return tuple(outs)
+
+ def train(self, mode=True):
+ super(VisionTransformer, self).train(mode)
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, nn.LayerNorm):
+ m.eval()
diff --git a/src/custom_mmpkg/custom_mmseg/models/builder.py b/src/custom_mmpkg/custom_mmseg/models/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6e5920e9dec9d62e5a62ed688cab7d3bfd1ac74
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/builder.py
@@ -0,0 +1,46 @@
+import warnings
+
+from custom_mmpkg.custom_mmcv.cnn import MODELS as MMCV_MODELS
+from custom_mmpkg.custom_mmcv.utils import Registry
+
+MODELS = Registry('models', parent=MMCV_MODELS)
+
+BACKBONES = MODELS
+NECKS = MODELS
+HEADS = MODELS
+LOSSES = MODELS
+SEGMENTORS = MODELS
+
+
+def build_backbone(cfg):
+ """Build backbone."""
+ return BACKBONES.build(cfg)
+
+
+def build_neck(cfg):
+ """Build neck."""
+ return NECKS.build(cfg)
+
+
+def build_head(cfg):
+ """Build head."""
+ return HEADS.build(cfg)
+
+
+def build_loss(cfg):
+ """Build loss."""
+ return LOSSES.build(cfg)
+
+
+def build_segmentor(cfg, train_cfg=None, test_cfg=None):
+ """Build segmentor."""
+ if train_cfg is not None or test_cfg is not None:
+ warnings.warn(
+ 'train_cfg and test_cfg is deprecated, '
+ 'please specify them in model', UserWarning)
+ assert cfg.get('train_cfg') is None or train_cfg is None, \
+ 'train_cfg specified in both outer field and model field '
+ assert cfg.get('test_cfg') is None or test_cfg is None, \
+ 'test_cfg specified in both outer field and model field '
+ return SEGMENTORS.build(
+ cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
diff --git a/src/custom_mmpkg/custom_mmseg/models/decode_heads/__init__.py b/src/custom_mmpkg/custom_mmseg/models/decode_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac66d3cfe0ea04af45c0f3594bf135841c3812e3
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/decode_heads/__init__.py
@@ -0,0 +1,28 @@
+from .ann_head import ANNHead
+from .apc_head import APCHead
+from .aspp_head import ASPPHead
+from .cc_head import CCHead
+from .da_head import DAHead
+from .dm_head import DMHead
+from .dnl_head import DNLHead
+from .ema_head import EMAHead
+from .enc_head import EncHead
+from .fcn_head import FCNHead
+from .fpn_head import FPNHead
+from .gc_head import GCHead
+from .lraspp_head import LRASPPHead
+from .nl_head import NLHead
+from .ocr_head import OCRHead
+# from .point_head import PointHead
+from .psa_head import PSAHead
+from .psp_head import PSPHead
+from .sep_aspp_head import DepthwiseSeparableASPPHead
+from .sep_fcn_head import DepthwiseSeparableFCNHead
+from .uper_head import UPerHead
+
+__all__ = [
+ 'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
+ 'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
+ 'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
+ 'APCHead', 'DMHead', 'LRASPPHead'
+]
diff --git a/src/custom_mmpkg/custom_mmseg/models/decode_heads/ann_head.py b/src/custom_mmpkg/custom_mmseg/models/decode_heads/ann_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..bebbc4f1ba6f76508a3f71265e519cbd24a509cc
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/decode_heads/ann_head.py
@@ -0,0 +1,245 @@
+import torch
+import torch.nn as nn
+from custom_mmpkg.custom_mmcv.cnn import ConvModule
+
+from ..builder import HEADS
+from ..utils import SelfAttentionBlock as _SelfAttentionBlock
+from .decode_head import BaseDecodeHead
+
+
+class PPMConcat(nn.ModuleList):
+ """Pyramid Pooling Module that only concat the features of each layer.
+
+ Args:
+ pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module.
+ """
+
+ def __init__(self, pool_scales=(1, 3, 6, 8)):
+ super(PPMConcat, self).__init__(
+ [nn.AdaptiveAvgPool2d(pool_scale) for pool_scale in pool_scales])
+
+ def forward(self, feats):
+ """Forward function."""
+ ppm_outs = []
+ for ppm in self:
+ ppm_out = ppm(feats)
+ ppm_outs.append(ppm_out.view(*feats.shape[:2], -1))
+ concat_outs = torch.cat(ppm_outs, dim=2)
+ return concat_outs
+
+
+class SelfAttentionBlock(_SelfAttentionBlock):
+ """Make a ANN used SelfAttentionBlock.
+
+ Args:
+ low_in_channels (int): Input channels of lower level feature,
+ which is the key feature for self-attention.
+ high_in_channels (int): Input channels of higher level feature,
+ which is the query feature for self-attention.
+ channels (int): Output channels of key/query transform.
+ out_channels (int): Output channels.
+ share_key_query (bool): Whether share projection weight between key
+ and query projection.
+ query_scale (int): The scale of query feature map.
+ key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module of key feature.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict|None): Config of activation layers.
+ """
+
+ def __init__(self, low_in_channels, high_in_channels, channels,
+ out_channels, share_key_query, query_scale, key_pool_scales,
+ conv_cfg, norm_cfg, act_cfg):
+ key_psp = PPMConcat(key_pool_scales)
+ if query_scale > 1:
+ query_downsample = nn.MaxPool2d(kernel_size=query_scale)
+ else:
+ query_downsample = None
+ super(SelfAttentionBlock, self).__init__(
+ key_in_channels=low_in_channels,
+ query_in_channels=high_in_channels,
+ channels=channels,
+ out_channels=out_channels,
+ share_key_query=share_key_query,
+ query_downsample=query_downsample,
+ key_downsample=key_psp,
+ key_query_num_convs=1,
+ key_query_norm=True,
+ value_out_num_convs=1,
+ value_out_norm=False,
+ matmul_norm=True,
+ with_out=True,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+
+class AFNB(nn.Module):
+ """Asymmetric Fusion Non-local Block(AFNB)
+
+ Args:
+ low_in_channels (int): Input channels of lower level feature,
+ which is the key feature for self-attention.
+ high_in_channels (int): Input channels of higher level feature,
+ which is the query feature for self-attention.
+ channels (int): Output channels of key/query transform.
+ out_channels (int): Output channels.
+ and query projection.
+ query_scales (tuple[int]): The scales of query feature map.
+ Default: (1,)
+ key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module of key feature.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict|None): Config of activation layers.
+ """
+
+ def __init__(self, low_in_channels, high_in_channels, channels,
+ out_channels, query_scales, key_pool_scales, conv_cfg,
+ norm_cfg, act_cfg):
+ super(AFNB, self).__init__()
+ self.stages = nn.ModuleList()
+ for query_scale in query_scales:
+ self.stages.append(
+ SelfAttentionBlock(
+ low_in_channels=low_in_channels,
+ high_in_channels=high_in_channels,
+ channels=channels,
+ out_channels=out_channels,
+ share_key_query=False,
+ query_scale=query_scale,
+ key_pool_scales=key_pool_scales,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ self.bottleneck = ConvModule(
+ out_channels + high_in_channels,
+ out_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+
+ def forward(self, low_feats, high_feats):
+ """Forward function."""
+ priors = [stage(high_feats, low_feats) for stage in self.stages]
+ context = torch.stack(priors, dim=0).sum(dim=0)
+ output = self.bottleneck(torch.cat([context, high_feats], 1))
+ return output
+
+
+class APNB(nn.Module):
+ """Asymmetric Pyramid Non-local Block (APNB)
+
+ Args:
+ in_channels (int): Input channels of key/query feature,
+ which is the key feature for self-attention.
+ channels (int): Output channels of key/query transform.
+ out_channels (int): Output channels.
+ query_scales (tuple[int]): The scales of query feature map.
+ Default: (1,)
+ key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module of key feature.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict|None): Config of activation layers.
+ """
+
+ def __init__(self, in_channels, channels, out_channels, query_scales,
+ key_pool_scales, conv_cfg, norm_cfg, act_cfg):
+ super(APNB, self).__init__()
+ self.stages = nn.ModuleList()
+ for query_scale in query_scales:
+ self.stages.append(
+ SelfAttentionBlock(
+ low_in_channels=in_channels,
+ high_in_channels=in_channels,
+ channels=channels,
+ out_channels=out_channels,
+ share_key_query=True,
+ query_scale=query_scale,
+ key_pool_scales=key_pool_scales,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ self.bottleneck = ConvModule(
+ 2 * in_channels,
+ out_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ def forward(self, feats):
+ """Forward function."""
+ priors = [stage(feats, feats) for stage in self.stages]
+ context = torch.stack(priors, dim=0).sum(dim=0)
+ output = self.bottleneck(torch.cat([context, feats], 1))
+ return output
+
+
+@HEADS.register_module()
+class ANNHead(BaseDecodeHead):
+ """Asymmetric Non-local Neural Networks for Semantic Segmentation.
+
+ This head is the implementation of `ANNNet
+ `_.
+
+ Args:
+ project_channels (int): Projection channels for Nonlocal.
+ query_scales (tuple[int]): The scales of query feature map.
+ Default: (1,)
+ key_pool_scales (tuple[int]): The pooling scales of key feature map.
+ Default: (1, 3, 6, 8).
+ """
+
+ def __init__(self,
+ project_channels,
+ query_scales=(1, ),
+ key_pool_scales=(1, 3, 6, 8),
+ **kwargs):
+ super(ANNHead, self).__init__(
+ input_transform='multiple_select', **kwargs)
+ assert len(self.in_channels) == 2
+ low_in_channels, high_in_channels = self.in_channels
+ self.project_channels = project_channels
+ self.fusion = AFNB(
+ low_in_channels=low_in_channels,
+ high_in_channels=high_in_channels,
+ out_channels=high_in_channels,
+ channels=project_channels,
+ query_scales=query_scales,
+ key_pool_scales=key_pool_scales,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.bottleneck = ConvModule(
+ high_in_channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.context = APNB(
+ in_channels=self.channels,
+ out_channels=self.channels,
+ channels=project_channels,
+ query_scales=query_scales,
+ key_pool_scales=key_pool_scales,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ low_feats, high_feats = self._transform_inputs(inputs)
+ output = self.fusion(low_feats, high_feats)
+ output = self.dropout(output)
+ output = self.bottleneck(output)
+ output = self.context(output)
+ output = self.cls_seg(output)
+
+ return output
diff --git a/src/custom_mmpkg/custom_mmseg/models/decode_heads/apc_head.py b/src/custom_mmpkg/custom_mmseg/models/decode_heads/apc_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..119c083a3422b939615a2310d647993d31cb4dc0
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/decode_heads/apc_head.py
@@ -0,0 +1,158 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from custom_mmpkg.custom_mmcv.cnn import ConvModule
+
+from custom_mmpkg.custom_mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+class ACM(nn.Module):
+ """Adaptive Context Module used in APCNet.
+
+ Args:
+ pool_scale (int): Pooling scale used in Adaptive Context
+ Module to extract region features.
+ fusion (bool): Add one conv to fuse residual feature.
+ in_channels (int): Input channels.
+ channels (int): Channels after modules, before conv_seg.
+ conv_cfg (dict | None): Config of conv layers.
+ norm_cfg (dict | None): Config of norm layers.
+ act_cfg (dict): Config of activation layers.
+ """
+
+ def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg,
+ norm_cfg, act_cfg):
+ super(ACM, self).__init__()
+ self.pool_scale = pool_scale
+ self.fusion = fusion
+ self.in_channels = in_channels
+ self.channels = channels
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.pooled_redu_conv = ConvModule(
+ self.in_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ self.input_redu_conv = ConvModule(
+ self.in_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ self.global_info = ConvModule(
+ self.channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0)
+
+ self.residual_conv = ConvModule(
+ self.channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ if self.fusion:
+ self.fusion_conv = ConvModule(
+ self.channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, x):
+ """Forward function."""
+ pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale)
+ # [batch_size, channels, h, w]
+ x = self.input_redu_conv(x)
+ # [batch_size, channels, pool_scale, pool_scale]
+ pooled_x = self.pooled_redu_conv(pooled_x)
+ batch_size = x.size(0)
+ # [batch_size, pool_scale * pool_scale, channels]
+ pooled_x = pooled_x.view(batch_size, self.channels,
+ -1).permute(0, 2, 1).contiguous()
+ # [batch_size, h * w, pool_scale * pool_scale]
+ affinity_matrix = self.gla(x + resize(
+ self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:])
+ ).permute(0, 2, 3, 1).reshape(
+ batch_size, -1, self.pool_scale**2)
+ affinity_matrix = F.sigmoid(affinity_matrix)
+ # [batch_size, h * w, channels]
+ z_out = torch.matmul(affinity_matrix, pooled_x)
+ # [batch_size, channels, h * w]
+ z_out = z_out.permute(0, 2, 1).contiguous()
+ # [batch_size, channels, h, w]
+ z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3))
+ z_out = self.residual_conv(z_out)
+ z_out = F.relu(z_out + x)
+ if self.fusion:
+ z_out = self.fusion_conv(z_out)
+
+ return z_out
+
+
+@HEADS.register_module()
+class APCHead(BaseDecodeHead):
+ """Adaptive Pyramid Context Network for Semantic Segmentation.
+
+ This head is the implementation of
+ `APCNet `_.
+
+ Args:
+ pool_scales (tuple[int]): Pooling scales used in Adaptive Context
+ Module. Default: (1, 2, 3, 6).
+ fusion (bool): Add one conv to fuse residual feature.
+ """
+
+ def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs):
+ super(APCHead, self).__init__(**kwargs)
+ assert isinstance(pool_scales, (list, tuple))
+ self.pool_scales = pool_scales
+ self.fusion = fusion
+ acm_modules = []
+ for pool_scale in self.pool_scales:
+ acm_modules.append(
+ ACM(pool_scale,
+ self.fusion,
+ self.in_channels,
+ self.channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ self.acm_modules = nn.ModuleList(acm_modules)
+ self.bottleneck = ConvModule(
+ self.in_channels + len(pool_scales) * self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ acm_outs = [x]
+ for acm_module in self.acm_modules:
+ acm_outs.append(acm_module(x))
+ acm_outs = torch.cat(acm_outs, dim=1)
+ output = self.bottleneck(acm_outs)
+ output = self.cls_seg(output)
+ return output
diff --git a/src/custom_mmpkg/custom_mmseg/models/decode_heads/aspp_head.py b/src/custom_mmpkg/custom_mmseg/models/decode_heads/aspp_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b251f2659b9800df341d214610f3766ef81a835
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/decode_heads/aspp_head.py
@@ -0,0 +1,107 @@
+import torch
+import torch.nn as nn
+from custom_mmpkg.custom_mmcv.cnn import ConvModule
+
+from custom_mmpkg.custom_mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+class ASPPModule(nn.ModuleList):
+ """Atrous Spatial Pyramid Pooling (ASPP) Module.
+
+ Args:
+ dilations (tuple[int]): Dilation rate of each layer.
+ in_channels (int): Input channels.
+ channels (int): Channels after modules, before conv_seg.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict): Config of activation layers.
+ """
+
+ def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg,
+ act_cfg):
+ super(ASPPModule, self).__init__()
+ self.dilations = dilations
+ self.in_channels = in_channels
+ self.channels = channels
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ for dilation in dilations:
+ self.append(
+ ConvModule(
+ self.in_channels,
+ self.channels,
+ 1 if dilation == 1 else 3,
+ dilation=dilation,
+ padding=0 if dilation == 1 else dilation,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+
+ def forward(self, x):
+ """Forward function."""
+ aspp_outs = []
+ for aspp_module in self:
+ aspp_outs.append(aspp_module(x))
+
+ return aspp_outs
+
+
+@HEADS.register_module()
+class ASPPHead(BaseDecodeHead):
+ """Rethinking Atrous Convolution for Semantic Image Segmentation.
+
+ This head is the implementation of `DeepLabV3
+ `_.
+
+ Args:
+ dilations (tuple[int]): Dilation rates for ASPP module.
+ Default: (1, 6, 12, 18).
+ """
+
+ def __init__(self, dilations=(1, 6, 12, 18), **kwargs):
+ super(ASPPHead, self).__init__(**kwargs)
+ assert isinstance(dilations, (list, tuple))
+ self.dilations = dilations
+ self.image_pool = nn.Sequential(
+ nn.AdaptiveAvgPool2d(1),
+ ConvModule(
+ self.in_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ self.aspp_modules = ASPPModule(
+ dilations,
+ self.in_channels,
+ self.channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.bottleneck = ConvModule(
+ (len(dilations) + 1) * self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ aspp_outs = [
+ resize(
+ self.image_pool(x),
+ size=x.size()[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ ]
+ aspp_outs.extend(self.aspp_modules(x))
+ aspp_outs = torch.cat(aspp_outs, dim=1)
+ output = self.bottleneck(aspp_outs)
+ output = self.cls_seg(output)
+ return output
diff --git a/src/custom_mmpkg/custom_mmseg/models/decode_heads/cascade_decode_head.py b/src/custom_mmpkg/custom_mmseg/models/decode_heads/cascade_decode_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..d02122ca0e68743b1bf7a893afae96042f23838c
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/decode_heads/cascade_decode_head.py
@@ -0,0 +1,57 @@
+from abc import ABCMeta, abstractmethod
+
+from .decode_head import BaseDecodeHead
+
+
+class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
+ """Base class for cascade decode head used in
+ :class:`CascadeEncoderDecoder."""
+
+ def __init__(self, *args, **kwargs):
+ super(BaseCascadeDecodeHead, self).__init__(*args, **kwargs)
+
+ @abstractmethod
+ def forward(self, inputs, prev_output):
+ """Placeholder of forward function."""
+ pass
+
+ def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg,
+ train_cfg):
+ """Forward function for training.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ prev_output (Tensor): The output of previous decode head.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ gt_semantic_seg (Tensor): Semantic segmentation masks
+ used if the architecture supports semantic segmentation task.
+ train_cfg (dict): The training config.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ seg_logits = self.forward(inputs, prev_output)
+ losses = self.losses(seg_logits, gt_semantic_seg)
+
+ return losses
+
+ def forward_test(self, inputs, prev_output, img_metas, test_cfg):
+ """Forward function for testing.
+
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ prev_output (Tensor): The output of previous decode head.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ test_cfg (dict): The testing config.
+
+ Returns:
+ Tensor: Output segmentation map.
+ """
+ return self.forward(inputs, prev_output)
diff --git a/src/custom_mmpkg/custom_mmseg/models/decode_heads/cc_head.py b/src/custom_mmpkg/custom_mmseg/models/decode_heads/cc_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..a582718478dab5c55eec3de6bcf7ac842da25e8d
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/decode_heads/cc_head.py
@@ -0,0 +1,42 @@
+import torch
+
+from ..builder import HEADS
+from .fcn_head import FCNHead
+
+try:
+ from custom_mmpkg.custom_mmcv.ops import CrissCrossAttention
+except ModuleNotFoundError:
+ CrissCrossAttention = None
+
+
+@HEADS.register_module()
+class CCHead(FCNHead):
+ """CCNet: Criss-Cross Attention for Semantic Segmentation.
+
+ This head is the implementation of `CCNet
+ `_.
+
+ Args:
+ recurrence (int): Number of recurrence of Criss Cross Attention
+ module. Default: 2.
+ """
+
+ def __init__(self, recurrence=2, **kwargs):
+ if CrissCrossAttention is None:
+ raise RuntimeError('Please install mmcv-full for '
+ 'CrissCrossAttention ops')
+ super(CCHead, self).__init__(num_convs=2, **kwargs)
+ self.recurrence = recurrence
+ self.cca = CrissCrossAttention(self.channels)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ output = self.convs[0](x)
+ for _ in range(self.recurrence):
+ output = self.cca(output)
+ output = self.convs[1](output)
+ if self.concat_input:
+ output = self.conv_cat(torch.cat([x, output], dim=1))
+ output = self.cls_seg(output)
+ return output
diff --git a/src/custom_mmpkg/custom_mmseg/models/decode_heads/da_head.py b/src/custom_mmpkg/custom_mmseg/models/decode_heads/da_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce384a5f040e815c61ffd4a0e46d058fa874e11a
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/decode_heads/da_head.py
@@ -0,0 +1,178 @@
+import torch
+import torch.nn.functional as F
+from custom_mmpkg.custom_mmcv.cnn import ConvModule, Scale
+from torch import nn
+
+from custom_mmpkg.custom_mmseg.core import add_prefix
+from ..builder import HEADS
+from ..utils import SelfAttentionBlock as _SelfAttentionBlock
+from .decode_head import BaseDecodeHead
+
+
+class PAM(_SelfAttentionBlock):
+ """Position Attention Module (PAM)
+
+ Args:
+ in_channels (int): Input channels of key/query feature.
+ channels (int): Output channels of key/query transform.
+ """
+
+ def __init__(self, in_channels, channels):
+ super(PAM, self).__init__(
+ key_in_channels=in_channels,
+ query_in_channels=in_channels,
+ channels=channels,
+ out_channels=in_channels,
+ share_key_query=False,
+ query_downsample=None,
+ key_downsample=None,
+ key_query_num_convs=1,
+ key_query_norm=False,
+ value_out_num_convs=1,
+ value_out_norm=False,
+ matmul_norm=False,
+ with_out=False,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=None)
+
+ self.gamma = Scale(0)
+
+ def forward(self, x):
+ """Forward function."""
+ out = super(PAM, self).forward(x, x)
+
+ out = self.gamma(out) + x
+ return out
+
+
+class CAM(nn.Module):
+ """Channel Attention Module (CAM)"""
+
+ def __init__(self):
+ super(CAM, self).__init__()
+ self.gamma = Scale(0)
+
+ def forward(self, x):
+ """Forward function."""
+ batch_size, channels, height, width = x.size()
+ proj_query = x.view(batch_size, channels, -1)
+ proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1)
+ energy = torch.bmm(proj_query, proj_key)
+ energy_new = torch.max(
+ energy, -1, keepdim=True)[0].expand_as(energy) - energy
+ attention = F.softmax(energy_new, dim=-1)
+ proj_value = x.view(batch_size, channels, -1)
+
+ out = torch.bmm(attention, proj_value)
+ out = out.view(batch_size, channels, height, width)
+
+ out = self.gamma(out) + x
+ return out
+
+
+@HEADS.register_module()
+class DAHead(BaseDecodeHead):
+ """Dual Attention Network for Scene Segmentation.
+
+ This head is the implementation of `DANet
+ `_.
+
+ Args:
+ pam_channels (int): The channels of Position Attention Module(PAM).
+ """
+
+ def __init__(self, pam_channels, **kwargs):
+ super(DAHead, self).__init__(**kwargs)
+ self.pam_channels = pam_channels
+ self.pam_in_conv = ConvModule(
+ self.in_channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.pam = PAM(self.channels, pam_channels)
+ self.pam_out_conv = ConvModule(
+ self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.pam_conv_seg = nn.Conv2d(
+ self.channels, self.num_classes, kernel_size=1)
+
+ self.cam_in_conv = ConvModule(
+ self.in_channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.cam = CAM()
+ self.cam_out_conv = ConvModule(
+ self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.cam_conv_seg = nn.Conv2d(
+ self.channels, self.num_classes, kernel_size=1)
+
+ def pam_cls_seg(self, feat):
+ """PAM feature classification."""
+ if self.dropout is not None:
+ feat = self.dropout(feat)
+ output = self.pam_conv_seg(feat)
+ return output
+
+ def cam_cls_seg(self, feat):
+ """CAM feature classification."""
+ if self.dropout is not None:
+ feat = self.dropout(feat)
+ output = self.cam_conv_seg(feat)
+ return output
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ pam_feat = self.pam_in_conv(x)
+ pam_feat = self.pam(pam_feat)
+ pam_feat = self.pam_out_conv(pam_feat)
+ pam_out = self.pam_cls_seg(pam_feat)
+
+ cam_feat = self.cam_in_conv(x)
+ cam_feat = self.cam(cam_feat)
+ cam_feat = self.cam_out_conv(cam_feat)
+ cam_out = self.cam_cls_seg(cam_feat)
+
+ feat_sum = pam_feat + cam_feat
+ pam_cam_out = self.cls_seg(feat_sum)
+
+ return pam_cam_out, pam_out, cam_out
+
+ def forward_test(self, inputs, img_metas, test_cfg):
+ """Forward function for testing, only ``pam_cam`` is used."""
+ return self.forward(inputs)[0]
+
+ def losses(self, seg_logit, seg_label):
+ """Compute ``pam_cam``, ``pam``, ``cam`` loss."""
+ pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit
+ loss = dict()
+ loss.update(
+ add_prefix(
+ super(DAHead, self).losses(pam_cam_seg_logit, seg_label),
+ 'pam_cam'))
+ loss.update(
+ add_prefix(
+ super(DAHead, self).losses(pam_seg_logit, seg_label), 'pam'))
+ loss.update(
+ add_prefix(
+ super(DAHead, self).losses(cam_seg_logit, seg_label), 'cam'))
+ return loss
diff --git a/src/custom_mmpkg/custom_mmseg/models/decode_heads/decode_head.py b/src/custom_mmpkg/custom_mmseg/models/decode_heads/decode_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ed88037dd0e2200d359a2e3dd40dc24ba40feeb
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/decode_heads/decode_head.py
@@ -0,0 +1,234 @@
+from abc import ABCMeta, abstractmethod
+
+import torch
+import torch.nn as nn
+from custom_mmpkg.custom_mmcv.cnn import normal_init
+from custom_mmpkg.custom_mmcv.runner import auto_fp16, force_fp32
+
+from custom_mmpkg.custom_mmseg.core import build_pixel_sampler
+from custom_mmpkg.custom_mmseg.ops import resize
+from ..builder import build_loss
+from ..losses import accuracy
+
+
+class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
+ """Base class for BaseDecodeHead.
+
+ Args:
+ in_channels (int|Sequence[int]): Input channels.
+ channels (int): Channels after modules, before conv_seg.
+ num_classes (int): Number of classes.
+ dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
+ conv_cfg (dict|None): Config of conv layers. Default: None.
+ norm_cfg (dict|None): Config of norm layers. Default: None.
+ act_cfg (dict): Config of activation layers.
+ Default: dict(type='ReLU')
+ in_index (int|Sequence[int]): Input feature index. Default: -1
+ input_transform (str|None): Transformation type of input features.
+ Options: 'resize_concat', 'multiple_select', None.
+ 'resize_concat': Multiple feature maps will be resize to the
+ same size as first one and than concat together.
+ Usually used in FCN head of HRNet.
+ 'multiple_select': Multiple feature maps will be bundle into
+ a list and passed into decode head.
+ None: Only one select feature map is allowed.
+ Default: None.
+ loss_decode (dict): Config of decode loss.
+ Default: dict(type='CrossEntropyLoss').
+ ignore_index (int | None): The label index to be ignored. When using
+ masked BCE loss, ignore_index should be set to None. Default: 255
+ sampler (dict|None): The config of segmentation map sampler.
+ Default: None.
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False.
+ """
+
+ def __init__(self,
+ in_channels,
+ channels,
+ *,
+ num_classes,
+ dropout_ratio=0.1,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU'),
+ in_index=-1,
+ input_transform=None,
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0),
+ ignore_index=255,
+ sampler=None,
+ align_corners=False):
+ super(BaseDecodeHead, self).__init__()
+ self._init_inputs(in_channels, in_index, input_transform)
+ self.channels = channels
+ self.num_classes = num_classes
+ self.dropout_ratio = dropout_ratio
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.in_index = in_index
+ self.loss_decode = build_loss(loss_decode)
+ self.ignore_index = ignore_index
+ self.align_corners = align_corners
+ if sampler is not None:
+ self.sampler = build_pixel_sampler(sampler, context=self)
+ else:
+ self.sampler = None
+
+ self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
+ if dropout_ratio > 0:
+ self.dropout = nn.Dropout2d(dropout_ratio)
+ else:
+ self.dropout = None
+ self.fp16_enabled = False
+
+ def extra_repr(self):
+ """Extra repr."""
+ s = f'input_transform={self.input_transform}, ' \
+ f'ignore_index={self.ignore_index}, ' \
+ f'align_corners={self.align_corners}'
+ return s
+
+ def _init_inputs(self, in_channels, in_index, input_transform):
+ """Check and initialize input transforms.
+
+ The in_channels, in_index and input_transform must match.
+ Specifically, when input_transform is None, only single feature map
+ will be selected. So in_channels and in_index must be of type int.
+ When input_transform
+
+ Args:
+ in_channels (int|Sequence[int]): Input channels.
+ in_index (int|Sequence[int]): Input feature index.
+ input_transform (str|None): Transformation type of input features.
+ Options: 'resize_concat', 'multiple_select', None.
+ 'resize_concat': Multiple feature maps will be resize to the
+ same size as first one and than concat together.
+ Usually used in FCN head of HRNet.
+ 'multiple_select': Multiple feature maps will be bundle into
+ a list and passed into decode head.
+ None: Only one select feature map is allowed.
+ """
+
+ if input_transform is not None:
+ assert input_transform in ['resize_concat', 'multiple_select']
+ self.input_transform = input_transform
+ self.in_index = in_index
+ if input_transform is not None:
+ assert isinstance(in_channels, (list, tuple))
+ assert isinstance(in_index, (list, tuple))
+ assert len(in_channels) == len(in_index)
+ if input_transform == 'resize_concat':
+ self.in_channels = sum(in_channels)
+ else:
+ self.in_channels = in_channels
+ else:
+ assert isinstance(in_channels, int)
+ assert isinstance(in_index, int)
+ self.in_channels = in_channels
+
+ def init_weights(self):
+ """Initialize weights of classification layer."""
+ normal_init(self.conv_seg, mean=0, std=0.01)
+
+ def _transform_inputs(self, inputs):
+ """Transform inputs for decoder.
+
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+
+ Returns:
+ Tensor: The transformed inputs
+ """
+
+ if self.input_transform == 'resize_concat':
+ inputs = [inputs[i] for i in self.in_index]
+ upsampled_inputs = [
+ resize(
+ input=x,
+ size=inputs[0].shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners) for x in inputs
+ ]
+ inputs = torch.cat(upsampled_inputs, dim=1)
+ elif self.input_transform == 'multiple_select':
+ inputs = [inputs[i] for i in self.in_index]
+ else:
+ inputs = inputs[self.in_index]
+
+ return inputs
+
+ @auto_fp16()
+ @abstractmethod
+ def forward(self, inputs):
+ """Placeholder of forward function."""
+ pass
+
+ def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
+ """Forward function for training.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ gt_semantic_seg (Tensor): Semantic segmentation masks
+ used if the architecture supports semantic segmentation task.
+ train_cfg (dict): The training config.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ seg_logits = self.forward(inputs)
+ losses = self.losses(seg_logits, gt_semantic_seg)
+ return losses
+
+ def forward_test(self, inputs, img_metas, test_cfg):
+ """Forward function for testing.
+
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ test_cfg (dict): The testing config.
+
+ Returns:
+ Tensor: Output segmentation map.
+ """
+ return self.forward(inputs)
+
+ def cls_seg(self, feat):
+ """Classify each pixel."""
+ if self.dropout is not None:
+ feat = self.dropout(feat)
+ output = self.conv_seg(feat)
+ return output
+
+ @force_fp32(apply_to=('seg_logit', ))
+ def losses(self, seg_logit, seg_label):
+ """Compute segmentation loss."""
+ loss = dict()
+ seg_logit = resize(
+ input=seg_logit,
+ size=seg_label.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ if self.sampler is not None:
+ seg_weight = self.sampler.sample(seg_logit, seg_label)
+ else:
+ seg_weight = None
+ seg_label = seg_label.squeeze(1)
+ loss['loss_seg'] = self.loss_decode(
+ seg_logit,
+ seg_label,
+ weight=seg_weight,
+ ignore_index=self.ignore_index)
+ loss['acc_seg'] = accuracy(seg_logit, seg_label)
+ return loss
diff --git a/src/custom_mmpkg/custom_mmseg/models/decode_heads/dm_head.py b/src/custom_mmpkg/custom_mmseg/models/decode_heads/dm_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..607cd3dd2219a7971319c84ad2383bca25306b3d
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/decode_heads/dm_head.py
@@ -0,0 +1,140 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from custom_mmpkg.custom_mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
+
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+class DCM(nn.Module):
+ """Dynamic Convolutional Module used in DMNet.
+
+ Args:
+ filter_size (int): The filter size of generated convolution kernel
+ used in Dynamic Convolutional Module.
+ fusion (bool): Add one conv to fuse DCM output feature.
+ in_channels (int): Input channels.
+ channels (int): Channels after modules, before conv_seg.
+ conv_cfg (dict | None): Config of conv layers.
+ norm_cfg (dict | None): Config of norm layers.
+ act_cfg (dict): Config of activation layers.
+ """
+
+ def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg,
+ norm_cfg, act_cfg):
+ super(DCM, self).__init__()
+ self.filter_size = filter_size
+ self.fusion = fusion
+ self.in_channels = in_channels
+ self.channels = channels
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.filter_gen_conv = nn.Conv2d(self.in_channels, self.channels, 1, 1,
+ 0)
+
+ self.input_redu_conv = ConvModule(
+ self.in_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ if self.norm_cfg is not None:
+ self.norm = build_norm_layer(self.norm_cfg, self.channels)[1]
+ else:
+ self.norm = None
+ self.activate = build_activation_layer(self.act_cfg)
+
+ if self.fusion:
+ self.fusion_conv = ConvModule(
+ self.channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, x):
+ """Forward function."""
+ generated_filter = self.filter_gen_conv(
+ F.adaptive_avg_pool2d(x, self.filter_size))
+ x = self.input_redu_conv(x)
+ b, c, h, w = x.shape
+ # [1, b * c, h, w], c = self.channels
+ x = x.view(1, b * c, h, w)
+ # [b * c, 1, filter_size, filter_size]
+ generated_filter = generated_filter.view(b * c, 1, self.filter_size,
+ self.filter_size)
+ pad = (self.filter_size - 1) // 2
+ if (self.filter_size - 1) % 2 == 0:
+ p2d = (pad, pad, pad, pad)
+ else:
+ p2d = (pad + 1, pad, pad + 1, pad)
+ x = F.pad(input=x, pad=p2d, mode='constant', value=0)
+ # [1, b * c, h, w]
+ output = F.conv2d(input=x, weight=generated_filter, groups=b * c)
+ # [b, c, h, w]
+ output = output.view(b, c, h, w)
+ if self.norm is not None:
+ output = self.norm(output)
+ output = self.activate(output)
+
+ if self.fusion:
+ output = self.fusion_conv(output)
+
+ return output
+
+
+@HEADS.register_module()
+class DMHead(BaseDecodeHead):
+ """Dynamic Multi-scale Filters for Semantic Segmentation.
+
+ This head is the implementation of
+ `DMNet `_.
+
+ Args:
+ filter_sizes (tuple[int]): The size of generated convolutional filters
+ used in Dynamic Convolutional Module. Default: (1, 3, 5, 7).
+ fusion (bool): Add one conv to fuse DCM output feature.
+ """
+
+ def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs):
+ super(DMHead, self).__init__(**kwargs)
+ assert isinstance(filter_sizes, (list, tuple))
+ self.filter_sizes = filter_sizes
+ self.fusion = fusion
+ dcm_modules = []
+ for filter_size in self.filter_sizes:
+ dcm_modules.append(
+ DCM(filter_size,
+ self.fusion,
+ self.in_channels,
+ self.channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ self.dcm_modules = nn.ModuleList(dcm_modules)
+ self.bottleneck = ConvModule(
+ self.in_channels + len(filter_sizes) * self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ dcm_outs = [x]
+ for dcm_module in self.dcm_modules:
+ dcm_outs.append(dcm_module(x))
+ dcm_outs = torch.cat(dcm_outs, dim=1)
+ output = self.bottleneck(dcm_outs)
+ output = self.cls_seg(output)
+ return output
diff --git a/src/custom_mmpkg/custom_mmseg/models/decode_heads/dnl_head.py b/src/custom_mmpkg/custom_mmseg/models/decode_heads/dnl_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed5b7c1936aa6114d0370625482b677db58a43e8
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/decode_heads/dnl_head.py
@@ -0,0 +1,131 @@
+import torch
+from custom_mmpkg.custom_mmcv.cnn import NonLocal2d
+from torch import nn
+
+from ..builder import HEADS
+from .fcn_head import FCNHead
+
+
+class DisentangledNonLocal2d(NonLocal2d):
+ """Disentangled Non-Local Blocks.
+
+ Args:
+ temperature (float): Temperature to adjust attention. Default: 0.05
+ """
+
+ def __init__(self, *arg, temperature, **kwargs):
+ super().__init__(*arg, **kwargs)
+ self.temperature = temperature
+ self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1)
+
+ def embedded_gaussian(self, theta_x, phi_x):
+ """Embedded gaussian with temperature."""
+
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
+ pairwise_weight = torch.matmul(theta_x, phi_x)
+ if self.use_scale:
+ # theta_x.shape[-1] is `self.inter_channels`
+ pairwise_weight /= theta_x.shape[-1]**0.5
+ pairwise_weight /= self.temperature
+ pairwise_weight = pairwise_weight.softmax(dim=-1)
+ return pairwise_weight
+
+ def forward(self, x):
+ # x: [N, C, H, W]
+ n = x.size(0)
+
+ # g_x: [N, HxW, C]
+ g_x = self.g(x).view(n, self.inter_channels, -1)
+ g_x = g_x.permute(0, 2, 1)
+
+ # theta_x: [N, HxW, C], phi_x: [N, C, HxW]
+ if self.mode == 'gaussian':
+ theta_x = x.view(n, self.in_channels, -1)
+ theta_x = theta_x.permute(0, 2, 1)
+ if self.sub_sample:
+ phi_x = self.phi(x).view(n, self.in_channels, -1)
+ else:
+ phi_x = x.view(n, self.in_channels, -1)
+ elif self.mode == 'concatenation':
+ theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
+ phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
+ else:
+ theta_x = self.theta(x).view(n, self.inter_channels, -1)
+ theta_x = theta_x.permute(0, 2, 1)
+ phi_x = self.phi(x).view(n, self.inter_channels, -1)
+
+ # subtract mean
+ theta_x -= theta_x.mean(dim=-2, keepdim=True)
+ phi_x -= phi_x.mean(dim=-1, keepdim=True)
+
+ pairwise_func = getattr(self, self.mode)
+ # pairwise_weight: [N, HxW, HxW]
+ pairwise_weight = pairwise_func(theta_x, phi_x)
+
+ # y: [N, HxW, C]
+ y = torch.matmul(pairwise_weight, g_x)
+ # y: [N, C, H, W]
+ y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
+ *x.size()[2:])
+
+ # unary_mask: [N, 1, HxW]
+ unary_mask = self.conv_mask(x)
+ unary_mask = unary_mask.view(n, 1, -1)
+ unary_mask = unary_mask.softmax(dim=-1)
+ # unary_x: [N, 1, C]
+ unary_x = torch.matmul(unary_mask, g_x)
+ # unary_x: [N, C, 1, 1]
+ unary_x = unary_x.permute(0, 2, 1).contiguous().reshape(
+ n, self.inter_channels, 1, 1)
+
+ output = x + self.conv_out(y + unary_x)
+
+ return output
+
+
+@HEADS.register_module()
+class DNLHead(FCNHead):
+ """Disentangled Non-Local Neural Networks.
+
+ This head is the implementation of `DNLNet
+ `_.
+
+ Args:
+ reduction (int): Reduction factor of projection transform. Default: 2.
+ use_scale (bool): Whether to scale pairwise_weight by
+ sqrt(1/inter_channels). Default: False.
+ mode (str): The nonlocal mode. Options are 'embedded_gaussian',
+ 'dot_product'. Default: 'embedded_gaussian.'.
+ temperature (float): Temperature to adjust attention. Default: 0.05
+ """
+
+ def __init__(self,
+ reduction=2,
+ use_scale=True,
+ mode='embedded_gaussian',
+ temperature=0.05,
+ **kwargs):
+ super(DNLHead, self).__init__(num_convs=2, **kwargs)
+ self.reduction = reduction
+ self.use_scale = use_scale
+ self.mode = mode
+ self.temperature = temperature
+ self.dnl_block = DisentangledNonLocal2d(
+ in_channels=self.channels,
+ reduction=self.reduction,
+ use_scale=self.use_scale,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ mode=self.mode,
+ temperature=self.temperature)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ output = self.convs[0](x)
+ output = self.dnl_block(output)
+ output = self.convs[1](output)
+ if self.concat_input:
+ output = self.conv_cat(torch.cat([x, output], dim=1))
+ output = self.cls_seg(output)
+ return output
diff --git a/src/custom_mmpkg/custom_mmseg/models/decode_heads/ema_head.py b/src/custom_mmpkg/custom_mmseg/models/decode_heads/ema_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2279d53be90e3aee8e109eae47277f7c3266cef
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/decode_heads/ema_head.py
@@ -0,0 +1,168 @@
+import math
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from custom_mmpkg.custom_mmcv.cnn import ConvModule
+
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+def reduce_mean(tensor):
+ """Reduce mean when distributed training."""
+ if not (dist.is_available() and dist.is_initialized()):
+ return tensor
+ tensor = tensor.clone()
+ dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
+ return tensor
+
+
+class EMAModule(nn.Module):
+ """Expectation Maximization Attention Module used in EMANet.
+
+ Args:
+ channels (int): Channels of the whole module.
+ num_bases (int): Number of bases.
+ num_stages (int): Number of the EM iterations.
+ """
+
+ def __init__(self, channels, num_bases, num_stages, momentum):
+ super(EMAModule, self).__init__()
+ assert num_stages >= 1, 'num_stages must be at least 1!'
+ self.num_bases = num_bases
+ self.num_stages = num_stages
+ self.momentum = momentum
+
+ bases = torch.zeros(1, channels, self.num_bases)
+ bases.normal_(0, math.sqrt(2. / self.num_bases))
+ # [1, channels, num_bases]
+ bases = F.normalize(bases, dim=1, p=2)
+ self.register_buffer('bases', bases)
+
+ def forward(self, feats):
+ """Forward function."""
+ batch_size, channels, height, width = feats.size()
+ # [batch_size, channels, height*width]
+ feats = feats.view(batch_size, channels, height * width)
+ # [batch_size, channels, num_bases]
+ bases = self.bases.repeat(batch_size, 1, 1)
+
+ with torch.no_grad():
+ for i in range(self.num_stages):
+ # [batch_size, height*width, num_bases]
+ attention = torch.einsum('bcn,bck->bnk', feats, bases)
+ attention = F.softmax(attention, dim=2)
+ # l1 norm
+ attention_normed = F.normalize(attention, dim=1, p=1)
+ # [batch_size, channels, num_bases]
+ bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
+ # l2 norm
+ bases = F.normalize(bases, dim=1, p=2)
+
+ feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)
+ feats_recon = feats_recon.view(batch_size, channels, height, width)
+
+ if self.training:
+ bases = bases.mean(dim=0, keepdim=True)
+ bases = reduce_mean(bases)
+ # l2 norm
+ bases = F.normalize(bases, dim=1, p=2)
+ self.bases = (1 -
+ self.momentum) * self.bases + self.momentum * bases
+
+ return feats_recon
+
+
+@HEADS.register_module()
+class EMAHead(BaseDecodeHead):
+ """Expectation Maximization Attention Networks for Semantic Segmentation.
+
+ This head is the implementation of `EMANet
+ `_.
+
+ Args:
+ ema_channels (int): EMA module channels
+ num_bases (int): Number of bases.
+ num_stages (int): Number of the EM iterations.
+ concat_input (bool): Whether concat the input and output of convs
+ before classification layer. Default: True
+ momentum (float): Momentum to update the base. Default: 0.1.
+ """
+
+ def __init__(self,
+ ema_channels,
+ num_bases,
+ num_stages,
+ concat_input=True,
+ momentum=0.1,
+ **kwargs):
+ super(EMAHead, self).__init__(**kwargs)
+ self.ema_channels = ema_channels
+ self.num_bases = num_bases
+ self.num_stages = num_stages
+ self.concat_input = concat_input
+ self.momentum = momentum
+ self.ema_module = EMAModule(self.ema_channels, self.num_bases,
+ self.num_stages, self.momentum)
+
+ self.ema_in_conv = ConvModule(
+ self.in_channels,
+ self.ema_channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ # project (0, inf) -> (-inf, inf)
+ self.ema_mid_conv = ConvModule(
+ self.ema_channels,
+ self.ema_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=None,
+ act_cfg=None)
+ for param in self.ema_mid_conv.parameters():
+ param.requires_grad = False
+
+ self.ema_out_conv = ConvModule(
+ self.ema_channels,
+ self.ema_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None)
+ self.bottleneck = ConvModule(
+ self.ema_channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ if self.concat_input:
+ self.conv_cat = ConvModule(
+ self.in_channels + self.channels,
+ self.channels,
+ kernel_size=3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ feats = self.ema_in_conv(x)
+ identity = feats
+ feats = self.ema_mid_conv(feats)
+ recon = self.ema_module(feats)
+ recon = F.relu(recon, inplace=True)
+ recon = self.ema_out_conv(recon)
+ output = F.relu(identity + recon, inplace=True)
+ output = self.bottleneck(output)
+ if self.concat_input:
+ output = self.conv_cat(torch.cat([x, output], dim=1))
+ output = self.cls_seg(output)
+ return output
diff --git a/src/custom_mmpkg/custom_mmseg/models/decode_heads/enc_head.py b/src/custom_mmpkg/custom_mmseg/models/decode_heads/enc_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee8ecd7401ec9619eb2ac176d380e4e513294ea3
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/decode_heads/enc_head.py
@@ -0,0 +1,187 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from custom_mmpkg.custom_mmcv.cnn import ConvModule, build_norm_layer
+
+from custom_mmpkg.custom_mmseg.ops import Encoding, resize
+from ..builder import HEADS, build_loss
+from .decode_head import BaseDecodeHead
+
+
+class EncModule(nn.Module):
+ """Encoding Module used in EncNet.
+
+ Args:
+ in_channels (int): Input channels.
+ num_codes (int): Number of code words.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict): Config of activation layers.
+ """
+
+ def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg):
+ super(EncModule, self).__init__()
+ self.encoding_project = ConvModule(
+ in_channels,
+ in_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ # TODO: resolve this hack
+ # change to 1d
+ if norm_cfg is not None:
+ encoding_norm_cfg = norm_cfg.copy()
+ if encoding_norm_cfg['type'] in ['BN', 'IN']:
+ encoding_norm_cfg['type'] += '1d'
+ else:
+ encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace(
+ '2d', '1d')
+ else:
+ # fallback to BN1d
+ encoding_norm_cfg = dict(type='BN1d')
+ self.encoding = nn.Sequential(
+ Encoding(channels=in_channels, num_codes=num_codes),
+ build_norm_layer(encoding_norm_cfg, num_codes)[1],
+ nn.ReLU(inplace=True))
+ self.fc = nn.Sequential(
+ nn.Linear(in_channels, in_channels), nn.Sigmoid())
+
+ def forward(self, x):
+ """Forward function."""
+ encoding_projection = self.encoding_project(x)
+ encoding_feat = self.encoding(encoding_projection).mean(dim=1)
+ batch_size, channels, _, _ = x.size()
+ gamma = self.fc(encoding_feat)
+ y = gamma.view(batch_size, channels, 1, 1)
+ output = F.relu_(x + x * y)
+ return encoding_feat, output
+
+
+@HEADS.register_module()
+class EncHead(BaseDecodeHead):
+ """Context Encoding for Semantic Segmentation.
+
+ This head is the implementation of `EncNet
+ `_.
+
+ Args:
+ num_codes (int): Number of code words. Default: 32.
+ use_se_loss (bool): Whether use Semantic Encoding Loss (SE-loss) to
+ regularize the training. Default: True.
+ add_lateral (bool): Whether use lateral connection to fuse features.
+ Default: False.
+ loss_se_decode (dict): Config of decode loss.
+ Default: dict(type='CrossEntropyLoss', use_sigmoid=True).
+ """
+
+ def __init__(self,
+ num_codes=32,
+ use_se_loss=True,
+ add_lateral=False,
+ loss_se_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=0.2),
+ **kwargs):
+ super(EncHead, self).__init__(
+ input_transform='multiple_select', **kwargs)
+ self.use_se_loss = use_se_loss
+ self.add_lateral = add_lateral
+ self.num_codes = num_codes
+ self.bottleneck = ConvModule(
+ self.in_channels[-1],
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ if add_lateral:
+ self.lateral_convs = nn.ModuleList()
+ for in_channels in self.in_channels[:-1]: # skip the last one
+ self.lateral_convs.append(
+ ConvModule(
+ in_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ self.fusion = ConvModule(
+ len(self.in_channels) * self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.enc_module = EncModule(
+ self.channels,
+ num_codes=num_codes,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ if self.use_se_loss:
+ self.loss_se_decode = build_loss(loss_se_decode)
+ self.se_layer = nn.Linear(self.channels, self.num_classes)
+
+ def forward(self, inputs):
+ """Forward function."""
+ inputs = self._transform_inputs(inputs)
+ feat = self.bottleneck(inputs[-1])
+ if self.add_lateral:
+ laterals = [
+ resize(
+ lateral_conv(inputs[i]),
+ size=feat.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+ feat = self.fusion(torch.cat([feat, *laterals], 1))
+ encode_feat, output = self.enc_module(feat)
+ output = self.cls_seg(output)
+ if self.use_se_loss:
+ se_output = self.se_layer(encode_feat)
+ return output, se_output
+ else:
+ return output
+
+ def forward_test(self, inputs, img_metas, test_cfg):
+ """Forward function for testing, ignore se_loss."""
+ if self.use_se_loss:
+ return self.forward(inputs)[0]
+ else:
+ return self.forward(inputs)
+
+ @staticmethod
+ def _convert_to_onehot_labels(seg_label, num_classes):
+ """Convert segmentation label to onehot.
+
+ Args:
+ seg_label (Tensor): Segmentation label of shape (N, H, W).
+ num_classes (int): Number of classes.
+
+ Returns:
+ Tensor: Onehot labels of shape (N, num_classes).
+ """
+
+ batch_size = seg_label.size(0)
+ onehot_labels = seg_label.new_zeros((batch_size, num_classes))
+ for i in range(batch_size):
+ hist = seg_label[i].float().histc(
+ bins=num_classes, min=0, max=num_classes - 1)
+ onehot_labels[i] = hist > 0
+ return onehot_labels
+
+ def losses(self, seg_logit, seg_label):
+ """Compute segmentation and semantic encoding loss."""
+ seg_logit, se_seg_logit = seg_logit
+ loss = dict()
+ loss.update(super(EncHead, self).losses(seg_logit, seg_label))
+ se_loss = self.loss_se_decode(
+ se_seg_logit,
+ self._convert_to_onehot_labels(seg_label, self.num_classes))
+ loss['loss_se'] = se_loss
+ return loss
diff --git a/src/custom_mmpkg/custom_mmseg/models/decode_heads/fcn_head.py b/src/custom_mmpkg/custom_mmseg/models/decode_heads/fcn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f0c384381a1f1b26f795e2ed53c571823858317
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/decode_heads/fcn_head.py
@@ -0,0 +1,81 @@
+import torch
+import torch.nn as nn
+from custom_mmpkg.custom_mmcv.cnn import ConvModule
+
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+@HEADS.register_module()
+class FCNHead(BaseDecodeHead):
+ """Fully Convolution Networks for Semantic Segmentation.
+
+ This head is implemented of `FCNNet `_.
+
+ Args:
+ num_convs (int): Number of convs in the head. Default: 2.
+ kernel_size (int): The kernel size for convs in the head. Default: 3.
+ concat_input (bool): Whether concat the input and output of convs
+ before classification layer.
+ dilation (int): The dilation rate for convs in the head. Default: 1.
+ """
+
+ def __init__(self,
+ num_convs=2,
+ kernel_size=3,
+ concat_input=True,
+ dilation=1,
+ **kwargs):
+ assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int)
+ self.num_convs = num_convs
+ self.concat_input = concat_input
+ self.kernel_size = kernel_size
+ super(FCNHead, self).__init__(**kwargs)
+ if num_convs == 0:
+ assert self.in_channels == self.channels
+
+ conv_padding = (kernel_size // 2) * dilation
+ convs = []
+ convs.append(
+ ConvModule(
+ self.in_channels,
+ self.channels,
+ kernel_size=kernel_size,
+ padding=conv_padding,
+ dilation=dilation,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ for i in range(num_convs - 1):
+ convs.append(
+ ConvModule(
+ self.channels,
+ self.channels,
+ kernel_size=kernel_size,
+ padding=conv_padding,
+ dilation=dilation,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ if num_convs == 0:
+ self.convs = nn.Identity()
+ else:
+ self.convs = nn.Sequential(*convs)
+ if self.concat_input:
+ self.conv_cat = ConvModule(
+ self.in_channels + self.channels,
+ self.channels,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ output = self.convs(x)
+ if self.concat_input:
+ output = self.conv_cat(torch.cat([x, output], dim=1))
+ output = self.cls_seg(output)
+ return output
diff --git a/src/custom_mmpkg/custom_mmseg/models/decode_heads/fpn_head.py b/src/custom_mmpkg/custom_mmseg/models/decode_heads/fpn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..26d0849b1ace5911974437be5ae328e6107b44bc
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/decode_heads/fpn_head.py
@@ -0,0 +1,68 @@
+import numpy as np
+import torch.nn as nn
+from custom_mmpkg.custom_mmcv.cnn import ConvModule
+
+from custom_mmpkg.custom_mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+@HEADS.register_module()
+class FPNHead(BaseDecodeHead):
+ """Panoptic Feature Pyramid Networks.
+
+ This head is the implementation of `Semantic FPN
+ `_.
+
+ Args:
+ feature_strides (tuple[int]): The strides for input feature maps.
+ stack_lateral. All strides suppose to be power of 2. The first
+ one is of largest resolution.
+ """
+
+ def __init__(self, feature_strides, **kwargs):
+ super(FPNHead, self).__init__(
+ input_transform='multiple_select', **kwargs)
+ assert len(feature_strides) == len(self.in_channels)
+ assert min(feature_strides) == feature_strides[0]
+ self.feature_strides = feature_strides
+
+ self.scale_heads = nn.ModuleList()
+ for i in range(len(feature_strides)):
+ head_length = max(
+ 1,
+ int(np.log2(feature_strides[i]) - np.log2(feature_strides[0])))
+ scale_head = []
+ for k in range(head_length):
+ scale_head.append(
+ ConvModule(
+ self.in_channels[i] if k == 0 else self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ if feature_strides[i] != feature_strides[0]:
+ scale_head.append(
+ nn.Upsample(
+ scale_factor=2,
+ mode='bilinear',
+ align_corners=self.align_corners))
+ self.scale_heads.append(nn.Sequential(*scale_head))
+
+ def forward(self, inputs):
+
+ x = self._transform_inputs(inputs)
+
+ output = self.scale_heads[0](x[0])
+ for i in range(1, len(self.feature_strides)):
+ # non inplace
+ output = output + resize(
+ self.scale_heads[i](x[i]),
+ size=output.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+
+ output = self.cls_seg(output)
+ return output
diff --git a/src/custom_mmpkg/custom_mmseg/models/decode_heads/gc_head.py b/src/custom_mmpkg/custom_mmseg/models/decode_heads/gc_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..600049998d04fc5f469e8da41243bb2d51b64cc1
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/decode_heads/gc_head.py
@@ -0,0 +1,47 @@
+import torch
+from custom_mmpkg.custom_mmcv.cnn import ContextBlock
+
+from ..builder import HEADS
+from .fcn_head import FCNHead
+
+
+@HEADS.register_module()
+class GCHead(FCNHead):
+ """GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond.
+
+ This head is the implementation of `GCNet
+ `_.
+
+ Args:
+ ratio (float): Multiplier of channels ratio. Default: 1/4.
+ pooling_type (str): The pooling type of context aggregation.
+ Options are 'att', 'avg'. Default: 'avg'.
+ fusion_types (tuple[str]): The fusion type for feature fusion.
+ Options are 'channel_add', 'channel_mul'. Default: ('channel_add',)
+ """
+
+ def __init__(self,
+ ratio=1 / 4.,
+ pooling_type='att',
+ fusion_types=('channel_add', ),
+ **kwargs):
+ super(GCHead, self).__init__(num_convs=2, **kwargs)
+ self.ratio = ratio
+ self.pooling_type = pooling_type
+ self.fusion_types = fusion_types
+ self.gc_block = ContextBlock(
+ in_channels=self.channels,
+ ratio=self.ratio,
+ pooling_type=self.pooling_type,
+ fusion_types=self.fusion_types)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ output = self.convs[0](x)
+ output = self.gc_block(output)
+ output = self.convs[1](output)
+ if self.concat_input:
+ output = self.conv_cat(torch.cat([x, output], dim=1))
+ output = self.cls_seg(output)
+ return output
diff --git a/src/custom_mmpkg/custom_mmseg/models/decode_heads/lraspp_head.py b/src/custom_mmpkg/custom_mmseg/models/decode_heads/lraspp_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..5395a8f57fdbdf6828842f2c4c9a291ecf0b2cdc
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/decode_heads/lraspp_head.py
@@ -0,0 +1,90 @@
+import torch
+import torch.nn as nn
+from custom_mmpkg.custom_mmcv import is_tuple_of
+from custom_mmpkg.custom_mmcv.cnn import ConvModule
+
+from custom_mmpkg.custom_mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+@HEADS.register_module()
+class LRASPPHead(BaseDecodeHead):
+ """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3.
+
+ This head is the improved implementation of `Searching for MobileNetV3
+ `_.
+
+ Args:
+ branch_channels (tuple[int]): The number of output channels in every
+ each branch. Default: (32, 64).
+ """
+
+ def __init__(self, branch_channels=(32, 64), **kwargs):
+ super(LRASPPHead, self).__init__(**kwargs)
+ if self.input_transform != 'multiple_select':
+ raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform '
+ f'must be \'multiple_select\'. But received '
+ f'\'{self.input_transform}\'')
+ assert is_tuple_of(branch_channels, int)
+ assert len(branch_channels) == len(self.in_channels) - 1
+ self.branch_channels = branch_channels
+
+ self.convs = nn.Sequential()
+ self.conv_ups = nn.Sequential()
+ for i in range(len(branch_channels)):
+ self.convs.add_module(
+ f'conv{i}',
+ nn.Conv2d(
+ self.in_channels[i], branch_channels[i], 1, bias=False))
+ self.conv_ups.add_module(
+ f'conv_up{i}',
+ ConvModule(
+ self.channels + branch_channels[i],
+ self.channels,
+ 1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ bias=False))
+
+ self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1)
+
+ self.aspp_conv = ConvModule(
+ self.in_channels[-1],
+ self.channels,
+ 1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ bias=False)
+ self.image_pool = nn.Sequential(
+ nn.AvgPool2d(kernel_size=49, stride=(16, 20)),
+ ConvModule(
+ self.in_channels[2],
+ self.channels,
+ 1,
+ act_cfg=dict(type='Sigmoid'),
+ bias=False))
+
+ def forward(self, inputs):
+ """Forward function."""
+ inputs = self._transform_inputs(inputs)
+
+ x = inputs[-1]
+
+ x = self.aspp_conv(x) * resize(
+ self.image_pool(x),
+ size=x.size()[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ x = self.conv_up_input(x)
+
+ for i in range(len(self.branch_channels) - 1, -1, -1):
+ x = resize(
+ x,
+ size=inputs[i].size()[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ x = torch.cat([x, self.convs[i](inputs[i])], 1)
+ x = self.conv_ups[i](x)
+
+ return self.cls_seg(x)
diff --git a/src/custom_mmpkg/custom_mmseg/models/decode_heads/nl_head.py b/src/custom_mmpkg/custom_mmseg/models/decode_heads/nl_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..9010d303cb3808de4893c3900ebbd9917a9cc57e
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/decode_heads/nl_head.py
@@ -0,0 +1,49 @@
+import torch
+from custom_mmpkg.custom_mmcv.cnn import NonLocal2d
+
+from ..builder import HEADS
+from .fcn_head import FCNHead
+
+
+@HEADS.register_module()
+class NLHead(FCNHead):
+ """Non-local Neural Networks.
+
+ This head is the implementation of `NLNet
+ `_.
+
+ Args:
+ reduction (int): Reduction factor of projection transform. Default: 2.
+ use_scale (bool): Whether to scale pairwise_weight by
+ sqrt(1/inter_channels). Default: True.
+ mode (str): The nonlocal mode. Options are 'embedded_gaussian',
+ 'dot_product'. Default: 'embedded_gaussian.'.
+ """
+
+ def __init__(self,
+ reduction=2,
+ use_scale=True,
+ mode='embedded_gaussian',
+ **kwargs):
+ super(NLHead, self).__init__(num_convs=2, **kwargs)
+ self.reduction = reduction
+ self.use_scale = use_scale
+ self.mode = mode
+ self.nl_block = NonLocal2d(
+ in_channels=self.channels,
+ reduction=self.reduction,
+ use_scale=self.use_scale,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ mode=self.mode)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ output = self.convs[0](x)
+ output = self.nl_block(output)
+ output = self.convs[1](output)
+ if self.concat_input:
+ output = self.conv_cat(torch.cat([x, output], dim=1))
+ output = self.cls_seg(output)
+ return output
diff --git a/src/custom_mmpkg/custom_mmseg/models/decode_heads/ocr_head.py b/src/custom_mmpkg/custom_mmseg/models/decode_heads/ocr_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..b31f00233355ea610ea61a4f40cc3dfb6d84c8b7
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/decode_heads/ocr_head.py
@@ -0,0 +1,127 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from custom_mmpkg.custom_mmcv.cnn import ConvModule
+
+from custom_mmpkg.custom_mmseg.ops import resize
+from ..builder import HEADS
+from ..utils import SelfAttentionBlock as _SelfAttentionBlock
+from .cascade_decode_head import BaseCascadeDecodeHead
+
+
+class SpatialGatherModule(nn.Module):
+ """Aggregate the context features according to the initial predicted
+ probability distribution.
+
+ Employ the soft-weighted method to aggregate the context.
+ """
+
+ def __init__(self, scale):
+ super(SpatialGatherModule, self).__init__()
+ self.scale = scale
+
+ def forward(self, feats, probs):
+ """Forward function."""
+ batch_size, num_classes, height, width = probs.size()
+ channels = feats.size(1)
+ probs = probs.view(batch_size, num_classes, -1)
+ feats = feats.view(batch_size, channels, -1)
+ # [batch_size, height*width, num_classes]
+ feats = feats.permute(0, 2, 1)
+ # [batch_size, channels, height*width]
+ probs = F.softmax(self.scale * probs, dim=2)
+ # [batch_size, channels, num_classes]
+ ocr_context = torch.matmul(probs, feats)
+ ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3)
+ return ocr_context
+
+
+class ObjectAttentionBlock(_SelfAttentionBlock):
+ """Make a OCR used SelfAttentionBlock."""
+
+ def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg,
+ act_cfg):
+ if scale > 1:
+ query_downsample = nn.MaxPool2d(kernel_size=scale)
+ else:
+ query_downsample = None
+ super(ObjectAttentionBlock, self).__init__(
+ key_in_channels=in_channels,
+ query_in_channels=in_channels,
+ channels=channels,
+ out_channels=in_channels,
+ share_key_query=False,
+ query_downsample=query_downsample,
+ key_downsample=None,
+ key_query_num_convs=2,
+ key_query_norm=True,
+ value_out_num_convs=1,
+ value_out_norm=True,
+ matmul_norm=True,
+ with_out=True,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.bottleneck = ConvModule(
+ in_channels * 2,
+ in_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, query_feats, key_feats):
+ """Forward function."""
+ context = super(ObjectAttentionBlock,
+ self).forward(query_feats, key_feats)
+ output = self.bottleneck(torch.cat([context, query_feats], dim=1))
+ if self.query_downsample is not None:
+ output = resize(query_feats)
+
+ return output
+
+
+@HEADS.register_module()
+class OCRHead(BaseCascadeDecodeHead):
+ """Object-Contextual Representations for Semantic Segmentation.
+
+ This head is the implementation of `OCRNet
+ `_.
+
+ Args:
+ ocr_channels (int): The intermediate channels of OCR block.
+ scale (int): The scale of probability map in SpatialGatherModule in
+ Default: 1.
+ """
+
+ def __init__(self, ocr_channels, scale=1, **kwargs):
+ super(OCRHead, self).__init__(**kwargs)
+ self.ocr_channels = ocr_channels
+ self.scale = scale
+ self.object_context_block = ObjectAttentionBlock(
+ self.channels,
+ self.ocr_channels,
+ self.scale,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.spatial_gather_module = SpatialGatherModule(self.scale)
+
+ self.bottleneck = ConvModule(
+ self.in_channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs, prev_output):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ feats = self.bottleneck(x)
+ context = self.spatial_gather_module(feats, prev_output)
+ object_context = self.object_context_block(feats, context)
+ output = self.cls_seg(object_context)
+
+ return output
diff --git a/src/custom_mmpkg/custom_mmseg/models/decode_heads/point_head.py b/src/custom_mmpkg/custom_mmseg/models/decode_heads/point_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..71c9f8e078536a733616a9f33de2dd35a8adbc26
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/decode_heads/point_head.py
@@ -0,0 +1,350 @@
+# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py # noqa
+
+import torch
+import torch.nn as nn
+
+from custom_mmpkg.custom_mmcv.cnn import ConvModule, normal_init
+from custom_mmpkg.custom_mmcv.ops import point_sample
+
+from custom_mmpkg.custom_mmseg.models.builder import HEADS
+from custom_mmpkg.custom_mmseg.ops import resize
+from ..losses import accuracy
+from .cascade_decode_head import BaseCascadeDecodeHead
+
+
+def calculate_uncertainty(seg_logits):
+ """Estimate uncertainty based on seg logits.
+
+ For each location of the prediction ``seg_logits`` we estimate
+ uncertainty as the difference between top first and top second
+ predicted logits.
+
+ Args:
+ seg_logits (Tensor): Semantic segmentation logits,
+ shape (batch_size, num_classes, height, width).
+
+ Returns:
+ scores (Tensor): T uncertainty scores with the most uncertain
+ locations having the highest uncertainty score, shape (
+ batch_size, 1, height, width)
+ """
+ top2_scores = torch.topk(seg_logits, k=2, dim=1)[0]
+ return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)
+
+
+@HEADS.register_module()
+class PointHead(BaseCascadeDecodeHead):
+ """A mask point head use in PointRend.
+
+ ``PointHead`` use shared multi-layer perceptron (equivalent to
+ nn.Conv1d) to predict the logit of input points. The fine-grained feature
+ and coarse feature will be concatenate together for predication.
+
+ Args:
+ num_fcs (int): Number of fc layers in the head. Default: 3.
+ in_channels (int): Number of input channels. Default: 256.
+ fc_channels (int): Number of fc channels. Default: 256.
+ num_classes (int): Number of classes for logits. Default: 80.
+ class_agnostic (bool): Whether use class agnostic classification.
+ If so, the output channels of logits will be 1. Default: False.
+ coarse_pred_each_layer (bool): Whether concatenate coarse feature with
+ the output of each fc layer. Default: True.
+ conv_cfg (dict|None): Dictionary to construct and config conv layer.
+ Default: dict(type='Conv1d'))
+ norm_cfg (dict|None): Dictionary to construct and config norm layer.
+ Default: None.
+ loss_point (dict): Dictionary to construct and config loss layer of
+ point head. Default: dict(type='CrossEntropyLoss', use_mask=True,
+ loss_weight=1.0).
+ """
+
+ def __init__(self,
+ num_fcs=3,
+ coarse_pred_each_layer=True,
+ conv_cfg=dict(type='Conv1d'),
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU', inplace=False),
+ **kwargs):
+ super(PointHead, self).__init__(
+ input_transform='multiple_select',
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ **kwargs)
+
+ self.num_fcs = num_fcs
+ self.coarse_pred_each_layer = coarse_pred_each_layer
+
+ fc_in_channels = sum(self.in_channels) + self.num_classes
+ fc_channels = self.channels
+ self.fcs = nn.ModuleList()
+ for k in range(num_fcs):
+ fc = ConvModule(
+ fc_in_channels,
+ fc_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.fcs.append(fc)
+ fc_in_channels = fc_channels
+ fc_in_channels += self.num_classes if self.coarse_pred_each_layer \
+ else 0
+ self.fc_seg = nn.Conv1d(
+ fc_in_channels,
+ self.num_classes,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ if self.dropout_ratio > 0:
+ self.dropout = nn.Dropout(self.dropout_ratio)
+ delattr(self, 'conv_seg')
+
+ def init_weights(self):
+ """Initialize weights of classification layer."""
+ normal_init(self.fc_seg, std=0.001)
+
+ def cls_seg(self, feat):
+ """Classify each pixel with fc."""
+ if self.dropout is not None:
+ feat = self.dropout(feat)
+ output = self.fc_seg(feat)
+ return output
+
+ def forward(self, fine_grained_point_feats, coarse_point_feats):
+ x = torch.cat([fine_grained_point_feats, coarse_point_feats], dim=1)
+ for fc in self.fcs:
+ x = fc(x)
+ if self.coarse_pred_each_layer:
+ x = torch.cat((x, coarse_point_feats), dim=1)
+ return self.cls_seg(x)
+
+ def _get_fine_grained_point_feats(self, x, points):
+ """Sample from fine grained features.
+
+ Args:
+ x (list[Tensor]): Feature pyramid from by neck or backbone.
+ points (Tensor): Point coordinates, shape (batch_size,
+ num_points, 2).
+
+ Returns:
+ fine_grained_feats (Tensor): Sampled fine grained feature,
+ shape (batch_size, sum(channels of x), num_points).
+ """
+
+ fine_grained_feats_list = [
+ point_sample(_, points, align_corners=self.align_corners)
+ for _ in x
+ ]
+ if len(fine_grained_feats_list) > 1:
+ fine_grained_feats = torch.cat(fine_grained_feats_list, dim=1)
+ else:
+ fine_grained_feats = fine_grained_feats_list[0]
+
+ return fine_grained_feats
+
+ def _get_coarse_point_feats(self, prev_output, points):
+ """Sample from fine grained features.
+
+ Args:
+ prev_output (list[Tensor]): Prediction of previous decode head.
+ points (Tensor): Point coordinates, shape (batch_size,
+ num_points, 2).
+
+ Returns:
+ coarse_feats (Tensor): Sampled coarse feature, shape (batch_size,
+ num_classes, num_points).
+ """
+
+ coarse_feats = point_sample(
+ prev_output, points, align_corners=self.align_corners)
+
+ return coarse_feats
+
+ def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg,
+ train_cfg):
+ """Forward function for training.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ prev_output (Tensor): The output of previous decode head.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ gt_semantic_seg (Tensor): Semantic segmentation masks
+ used if the architecture supports semantic segmentation task.
+ train_cfg (dict): The training config.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ x = self._transform_inputs(inputs)
+ with torch.no_grad():
+ points = self.get_points_train(
+ prev_output, calculate_uncertainty, cfg=train_cfg)
+ fine_grained_point_feats = self._get_fine_grained_point_feats(
+ x, points)
+ coarse_point_feats = self._get_coarse_point_feats(prev_output, points)
+ point_logits = self.forward(fine_grained_point_feats,
+ coarse_point_feats)
+ point_label = point_sample(
+ gt_semantic_seg.float(),
+ points,
+ mode='nearest',
+ align_corners=self.align_corners)
+ point_label = point_label.squeeze(1).long()
+
+ losses = self.losses(point_logits, point_label)
+
+ return losses
+
+ def forward_test(self, inputs, prev_output, img_metas, test_cfg):
+ """Forward function for testing.
+
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ prev_output (Tensor): The output of previous decode head.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ test_cfg (dict): The testing config.
+
+ Returns:
+ Tensor: Output segmentation map.
+ """
+
+ x = self._transform_inputs(inputs)
+ refined_seg_logits = prev_output.clone()
+ for _ in range(test_cfg.subdivision_steps):
+ refined_seg_logits = resize(
+ refined_seg_logits,
+ scale_factor=test_cfg.scale_factor,
+ mode='bilinear',
+ align_corners=self.align_corners)
+ batch_size, channels, height, width = refined_seg_logits.shape
+ point_indices, points = self.get_points_test(
+ refined_seg_logits, calculate_uncertainty, cfg=test_cfg)
+ fine_grained_point_feats = self._get_fine_grained_point_feats(
+ x, points)
+ coarse_point_feats = self._get_coarse_point_feats(
+ prev_output, points)
+ point_logits = self.forward(fine_grained_point_feats,
+ coarse_point_feats)
+
+ point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)
+ refined_seg_logits = refined_seg_logits.reshape(
+ batch_size, channels, height * width)
+ refined_seg_logits = refined_seg_logits.scatter_(
+ 2, point_indices, point_logits)
+ refined_seg_logits = refined_seg_logits.view(
+ batch_size, channels, height, width)
+
+ return refined_seg_logits
+
+ def losses(self, point_logits, point_label):
+ """Compute segmentation loss."""
+ loss = dict()
+ loss['loss_point'] = self.loss_decode(
+ point_logits, point_label, ignore_index=self.ignore_index)
+ loss['acc_point'] = accuracy(point_logits, point_label)
+ return loss
+
+ def get_points_train(self, seg_logits, uncertainty_func, cfg):
+ """Sample points for training.
+
+ Sample points in [0, 1] x [0, 1] coordinate space based on their
+ uncertainty. The uncertainties are calculated for each point using
+ 'uncertainty_func' function that takes point's logit prediction as
+ input.
+
+ Args:
+ seg_logits (Tensor): Semantic segmentation logits, shape (
+ batch_size, num_classes, height, width).
+ uncertainty_func (func): uncertainty calculation function.
+ cfg (dict): Training config of point head.
+
+ Returns:
+ point_coords (Tensor): A tensor of shape (batch_size, num_points,
+ 2) that contains the coordinates of ``num_points`` sampled
+ points.
+ """
+ num_points = cfg.num_points
+ oversample_ratio = cfg.oversample_ratio
+ importance_sample_ratio = cfg.importance_sample_ratio
+ assert oversample_ratio >= 1
+ assert 0 <= importance_sample_ratio <= 1
+ batch_size = seg_logits.shape[0]
+ num_sampled = int(num_points * oversample_ratio)
+ point_coords = torch.rand(
+ batch_size, num_sampled, 2, device=seg_logits.device)
+ point_logits = point_sample(seg_logits, point_coords)
+ # It is crucial to calculate uncertainty based on the sampled
+ # prediction value for the points. Calculating uncertainties of the
+ # coarse predictions first and sampling them for points leads to
+ # incorrect results. To illustrate this: assume uncertainty func(
+ # logits)=-abs(logits), a sampled point between two coarse
+ # predictions with -1 and 1 logits has 0 logits, and therefore 0
+ # uncertainty value. However, if we calculate uncertainties for the
+ # coarse predictions first, both will have -1 uncertainty,
+ # and sampled point will get -1 uncertainty.
+ point_uncertainties = uncertainty_func(point_logits)
+ num_uncertain_points = int(importance_sample_ratio * num_points)
+ num_random_points = num_points - num_uncertain_points
+ idx = torch.topk(
+ point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
+ shift = num_sampled * torch.arange(
+ batch_size, dtype=torch.long, device=seg_logits.device)
+ idx += shift[:, None]
+ point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
+ batch_size, num_uncertain_points, 2)
+ if num_random_points > 0:
+ rand_point_coords = torch.rand(
+ batch_size, num_random_points, 2, device=seg_logits.device)
+ point_coords = torch.cat((point_coords, rand_point_coords), dim=1)
+ return point_coords
+
+ def get_points_test(self, seg_logits, uncertainty_func, cfg):
+ """Sample points for testing.
+
+ Find ``num_points`` most uncertain points from ``uncertainty_map``.
+
+ Args:
+ seg_logits (Tensor): A tensor of shape (batch_size, num_classes,
+ height, width) for class-specific or class-agnostic prediction.
+ uncertainty_func (func): uncertainty calculation function.
+ cfg (dict): Testing config of point head.
+
+ Returns:
+ point_indices (Tensor): A tensor of shape (batch_size, num_points)
+ that contains indices from [0, height x width) of the most
+ uncertain points.
+ point_coords (Tensor): A tensor of shape (batch_size, num_points,
+ 2) that contains [0, 1] x [0, 1] normalized coordinates of the
+ most uncertain points from the ``height x width`` grid .
+ """
+
+ num_points = cfg.subdivision_num_points
+ uncertainty_map = uncertainty_func(seg_logits)
+ batch_size, _, height, width = uncertainty_map.shape
+ h_step = 1.0 / height
+ w_step = 1.0 / width
+
+ uncertainty_map = uncertainty_map.view(batch_size, height * width)
+ num_points = min(height * width, num_points)
+ point_indices = uncertainty_map.topk(num_points, dim=1)[1]
+ point_coords = torch.zeros(
+ batch_size,
+ num_points,
+ 2,
+ dtype=torch.float,
+ device=seg_logits.device)
+ point_coords[:, :, 0] = w_step / 2.0 + (point_indices %
+ width).float() * w_step
+ point_coords[:, :, 1] = h_step / 2.0 + (point_indices //
+ width).float() * h_step
+ return point_indices, point_coords
diff --git a/src/custom_mmpkg/custom_mmseg/models/decode_heads/psa_head.py b/src/custom_mmpkg/custom_mmseg/models/decode_heads/psa_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1424a85d7ca93eb299ad1e9116600f703940fc9
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/decode_heads/psa_head.py
@@ -0,0 +1,196 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from custom_mmpkg.custom_mmcv.cnn import ConvModule
+
+from custom_mmpkg.custom_mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+try:
+ from custom_mmpkg.custom_mmcv.ops import PSAMask
+except ModuleNotFoundError:
+ PSAMask = None
+
+
+@HEADS.register_module()
+class PSAHead(BaseDecodeHead):
+ """Point-wise Spatial Attention Network for Scene Parsing.
+
+ This head is the implementation of `PSANet
+ `_.
+
+ Args:
+ mask_size (tuple[int]): The PSA mask size. It usually equals input
+ size.
+ psa_type (str): The type of psa module. Options are 'collect',
+ 'distribute', 'bi-direction'. Default: 'bi-direction'
+ compact (bool): Whether use compact map for 'collect' mode.
+ Default: True.
+ shrink_factor (int): The downsample factors of psa mask. Default: 2.
+ normalization_factor (float): The normalize factor of attention.
+ psa_softmax (bool): Whether use softmax for attention.
+ """
+
+ def __init__(self,
+ mask_size,
+ psa_type='bi-direction',
+ compact=False,
+ shrink_factor=2,
+ normalization_factor=1.0,
+ psa_softmax=True,
+ **kwargs):
+ if PSAMask is None:
+ raise RuntimeError('Please install mmcv-full for PSAMask ops')
+ super(PSAHead, self).__init__(**kwargs)
+ assert psa_type in ['collect', 'distribute', 'bi-direction']
+ self.psa_type = psa_type
+ self.compact = compact
+ self.shrink_factor = shrink_factor
+ self.mask_size = mask_size
+ mask_h, mask_w = mask_size
+ self.psa_softmax = psa_softmax
+ if normalization_factor is None:
+ normalization_factor = mask_h * mask_w
+ self.normalization_factor = normalization_factor
+
+ self.reduce = ConvModule(
+ self.in_channels,
+ self.channels,
+ kernel_size=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.attention = nn.Sequential(
+ ConvModule(
+ self.channels,
+ self.channels,
+ kernel_size=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg),
+ nn.Conv2d(
+ self.channels, mask_h * mask_w, kernel_size=1, bias=False))
+ if psa_type == 'bi-direction':
+ self.reduce_p = ConvModule(
+ self.in_channels,
+ self.channels,
+ kernel_size=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.attention_p = nn.Sequential(
+ ConvModule(
+ self.channels,
+ self.channels,
+ kernel_size=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg),
+ nn.Conv2d(
+ self.channels, mask_h * mask_w, kernel_size=1, bias=False))
+ self.psamask_collect = PSAMask('collect', mask_size)
+ self.psamask_distribute = PSAMask('distribute', mask_size)
+ else:
+ self.psamask = PSAMask(psa_type, mask_size)
+ self.proj = ConvModule(
+ self.channels * (2 if psa_type == 'bi-direction' else 1),
+ self.in_channels,
+ kernel_size=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.bottleneck = ConvModule(
+ self.in_channels * 2,
+ self.channels,
+ kernel_size=3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ identity = x
+ align_corners = self.align_corners
+ if self.psa_type in ['collect', 'distribute']:
+ out = self.reduce(x)
+ n, c, h, w = out.size()
+ if self.shrink_factor != 1:
+ if h % self.shrink_factor and w % self.shrink_factor:
+ h = (h - 1) // self.shrink_factor + 1
+ w = (w - 1) // self.shrink_factor + 1
+ align_corners = True
+ else:
+ h = h // self.shrink_factor
+ w = w // self.shrink_factor
+ align_corners = False
+ out = resize(
+ out,
+ size=(h, w),
+ mode='bilinear',
+ align_corners=align_corners)
+ y = self.attention(out)
+ if self.compact:
+ if self.psa_type == 'collect':
+ y = y.view(n, h * w,
+ h * w).transpose(1, 2).view(n, h * w, h, w)
+ else:
+ y = self.psamask(y)
+ if self.psa_softmax:
+ y = F.softmax(y, dim=1)
+ out = torch.bmm(
+ out.view(n, c, h * w), y.view(n, h * w, h * w)).view(
+ n, c, h, w) * (1.0 / self.normalization_factor)
+ else:
+ x_col = self.reduce(x)
+ x_dis = self.reduce_p(x)
+ n, c, h, w = x_col.size()
+ if self.shrink_factor != 1:
+ if h % self.shrink_factor and w % self.shrink_factor:
+ h = (h - 1) // self.shrink_factor + 1
+ w = (w - 1) // self.shrink_factor + 1
+ align_corners = True
+ else:
+ h = h // self.shrink_factor
+ w = w // self.shrink_factor
+ align_corners = False
+ x_col = resize(
+ x_col,
+ size=(h, w),
+ mode='bilinear',
+ align_corners=align_corners)
+ x_dis = resize(
+ x_dis,
+ size=(h, w),
+ mode='bilinear',
+ align_corners=align_corners)
+ y_col = self.attention(x_col)
+ y_dis = self.attention_p(x_dis)
+ if self.compact:
+ y_dis = y_dis.view(n, h * w,
+ h * w).transpose(1, 2).view(n, h * w, h, w)
+ else:
+ y_col = self.psamask_collect(y_col)
+ y_dis = self.psamask_distribute(y_dis)
+ if self.psa_softmax:
+ y_col = F.softmax(y_col, dim=1)
+ y_dis = F.softmax(y_dis, dim=1)
+ x_col = torch.bmm(
+ x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view(
+ n, c, h, w) * (1.0 / self.normalization_factor)
+ x_dis = torch.bmm(
+ x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view(
+ n, c, h, w) * (1.0 / self.normalization_factor)
+ out = torch.cat([x_col, x_dis], 1)
+ out = self.proj(out)
+ out = resize(
+ out,
+ size=identity.shape[2:],
+ mode='bilinear',
+ align_corners=align_corners)
+ out = self.bottleneck(torch.cat((identity, out), dim=1))
+ out = self.cls_seg(out)
+ return out
diff --git a/src/custom_mmpkg/custom_mmseg/models/decode_heads/psp_head.py b/src/custom_mmpkg/custom_mmseg/models/decode_heads/psp_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f7880f21319ac6035f604b0fc54f2237f7ed988
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/decode_heads/psp_head.py
@@ -0,0 +1,101 @@
+import torch
+import torch.nn as nn
+from custom_mmpkg.custom_mmcv.cnn import ConvModule
+
+from custom_mmpkg.custom_mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+class PPM(nn.ModuleList):
+ """Pooling Pyramid Module used in PSPNet.
+
+ Args:
+ pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module.
+ in_channels (int): Input channels.
+ channels (int): Channels after modules, before conv_seg.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict): Config of activation layers.
+ align_corners (bool): align_corners argument of F.interpolate.
+ """
+
+ def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
+ act_cfg, align_corners):
+ super(PPM, self).__init__()
+ self.pool_scales = pool_scales
+ self.align_corners = align_corners
+ self.in_channels = in_channels
+ self.channels = channels
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ for pool_scale in pool_scales:
+ self.append(
+ nn.Sequential(
+ nn.AdaptiveAvgPool2d(pool_scale),
+ ConvModule(
+ self.in_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)))
+
+ def forward(self, x):
+ """Forward function."""
+ ppm_outs = []
+ for ppm in self:
+ ppm_out = ppm(x)
+ upsampled_ppm_out = resize(
+ ppm_out,
+ size=x.size()[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ ppm_outs.append(upsampled_ppm_out)
+ return ppm_outs
+
+
+@HEADS.register_module()
+class PSPHead(BaseDecodeHead):
+ """Pyramid Scene Parsing Network.
+
+ This head is the implementation of
+ `PSPNet `_.
+
+ Args:
+ pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module. Default: (1, 2, 3, 6).
+ """
+
+ def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
+ super(PSPHead, self).__init__(**kwargs)
+ assert isinstance(pool_scales, (list, tuple))
+ self.pool_scales = pool_scales
+ self.psp_modules = PPM(
+ self.pool_scales,
+ self.in_channels,
+ self.channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ align_corners=self.align_corners)
+ self.bottleneck = ConvModule(
+ self.in_channels + len(pool_scales) * self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ psp_outs = [x]
+ psp_outs.extend(self.psp_modules(x))
+ psp_outs = torch.cat(psp_outs, dim=1)
+ output = self.bottleneck(psp_outs)
+ output = self.cls_seg(output)
+ return output
diff --git a/src/custom_mmpkg/custom_mmseg/models/decode_heads/sep_aspp_head.py b/src/custom_mmpkg/custom_mmseg/models/decode_heads/sep_aspp_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..26942ae135d172ec2dbb3775c0dc548c6976e729
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/decode_heads/sep_aspp_head.py
@@ -0,0 +1,101 @@
+import torch
+import torch.nn as nn
+from custom_mmpkg.custom_mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
+
+from custom_mmpkg.custom_mmseg.ops import resize
+from ..builder import HEADS
+from .aspp_head import ASPPHead, ASPPModule
+
+
+class DepthwiseSeparableASPPModule(ASPPModule):
+ """Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable
+ conv."""
+
+ def __init__(self, **kwargs):
+ super(DepthwiseSeparableASPPModule, self).__init__(**kwargs)
+ for i, dilation in enumerate(self.dilations):
+ if dilation > 1:
+ self[i] = DepthwiseSeparableConvModule(
+ self.in_channels,
+ self.channels,
+ 3,
+ dilation=dilation,
+ padding=dilation,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+
+@HEADS.register_module()
+class DepthwiseSeparableASPPHead(ASPPHead):
+ """Encoder-Decoder with Atrous Separable Convolution for Semantic Image
+ Segmentation.
+
+ This head is the implementation of `DeepLabV3+
+ `_.
+
+ Args:
+ c1_in_channels (int): The input channels of c1 decoder. If is 0,
+ the no decoder will be used.
+ c1_channels (int): The intermediate channels of c1 decoder.
+ """
+
+ def __init__(self, c1_in_channels, c1_channels, **kwargs):
+ super(DepthwiseSeparableASPPHead, self).__init__(**kwargs)
+ assert c1_in_channels >= 0
+ self.aspp_modules = DepthwiseSeparableASPPModule(
+ dilations=self.dilations,
+ in_channels=self.in_channels,
+ channels=self.channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ if c1_in_channels > 0:
+ self.c1_bottleneck = ConvModule(
+ c1_in_channels,
+ c1_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ else:
+ self.c1_bottleneck = None
+ self.sep_bottleneck = nn.Sequential(
+ DepthwiseSeparableConvModule(
+ self.channels + c1_channels,
+ self.channels,
+ 3,
+ padding=1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg),
+ DepthwiseSeparableConvModule(
+ self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ aspp_outs = [
+ resize(
+ self.image_pool(x),
+ size=x.size()[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ ]
+ aspp_outs.extend(self.aspp_modules(x))
+ aspp_outs = torch.cat(aspp_outs, dim=1)
+ output = self.bottleneck(aspp_outs)
+ if self.c1_bottleneck is not None:
+ c1_output = self.c1_bottleneck(inputs[0])
+ output = resize(
+ input=output,
+ size=c1_output.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ output = torch.cat([output, c1_output], dim=1)
+ output = self.sep_bottleneck(output)
+ output = self.cls_seg(output)
+ return output
diff --git a/src/custom_mmpkg/custom_mmseg/models/decode_heads/sep_fcn_head.py b/src/custom_mmpkg/custom_mmseg/models/decode_heads/sep_fcn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..fabb624a530098e44ed1d9a9a7762addeaa126d6
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/decode_heads/sep_fcn_head.py
@@ -0,0 +1,51 @@
+from custom_mmpkg.custom_mmcv.cnn import DepthwiseSeparableConvModule
+
+from ..builder import HEADS
+from .fcn_head import FCNHead
+
+
+@HEADS.register_module()
+class DepthwiseSeparableFCNHead(FCNHead):
+ """Depthwise-Separable Fully Convolutional Network for Semantic
+ Segmentation.
+
+ This head is implemented according to Fast-SCNN paper.
+ Args:
+ in_channels(int): Number of output channels of FFM.
+ channels(int): Number of middle-stage channels in the decode head.
+ concat_input(bool): Whether to concatenate original decode input into
+ the result of several consecutive convolution layers.
+ Default: True.
+ num_classes(int): Used to determine the dimension of
+ final prediction tensor.
+ in_index(int): Correspond with 'out_indices' in FastSCNN backbone.
+ norm_cfg (dict | None): Config of norm layers.
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False.
+ loss_decode(dict): Config of loss type and some
+ relevant additional options.
+ """
+
+ def __init__(self, **kwargs):
+ super(DepthwiseSeparableFCNHead, self).__init__(**kwargs)
+ self.convs[0] = DepthwiseSeparableConvModule(
+ self.in_channels,
+ self.channels,
+ kernel_size=self.kernel_size,
+ padding=self.kernel_size // 2,
+ norm_cfg=self.norm_cfg)
+ for i in range(1, self.num_convs):
+ self.convs[i] = DepthwiseSeparableConvModule(
+ self.channels,
+ self.channels,
+ kernel_size=self.kernel_size,
+ padding=self.kernel_size // 2,
+ norm_cfg=self.norm_cfg)
+
+ if self.concat_input:
+ self.conv_cat = DepthwiseSeparableConvModule(
+ self.in_channels + self.channels,
+ self.channels,
+ kernel_size=self.kernel_size,
+ padding=self.kernel_size // 2,
+ norm_cfg=self.norm_cfg)
diff --git a/src/custom_mmpkg/custom_mmseg/models/decode_heads/uper_head.py b/src/custom_mmpkg/custom_mmseg/models/decode_heads/uper_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4990010074568484b8ea768bbca8e43d407659a
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/decode_heads/uper_head.py
@@ -0,0 +1,126 @@
+import torch
+import torch.nn as nn
+from custom_mmpkg.custom_mmcv.cnn import ConvModule
+
+from custom_mmpkg.custom_mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+from .psp_head import PPM
+
+
+@HEADS.register_module()
+class UPerHead(BaseDecodeHead):
+ """Unified Perceptual Parsing for Scene Understanding.
+
+ This head is the implementation of `UPerNet
+ `_.
+
+ Args:
+ pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module applied on the last feature. Default: (1, 2, 3, 6).
+ """
+
+ def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
+ super(UPerHead, self).__init__(
+ input_transform='multiple_select', **kwargs)
+ # PSP Module
+ self.psp_modules = PPM(
+ pool_scales,
+ self.in_channels[-1],
+ self.channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ align_corners=self.align_corners)
+ self.bottleneck = ConvModule(
+ self.in_channels[-1] + len(pool_scales) * self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ # FPN Module
+ self.lateral_convs = nn.ModuleList()
+ self.fpn_convs = nn.ModuleList()
+ for in_channels in self.in_channels[:-1]: # skip the top layer
+ l_conv = ConvModule(
+ in_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ inplace=False)
+ fpn_conv = ConvModule(
+ self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ inplace=False)
+ self.lateral_convs.append(l_conv)
+ self.fpn_convs.append(fpn_conv)
+
+ self.fpn_bottleneck = ConvModule(
+ len(self.in_channels) * self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def psp_forward(self, inputs):
+ """Forward function of PSP module."""
+ x = inputs[-1]
+ psp_outs = [x]
+ psp_outs.extend(self.psp_modules(x))
+ psp_outs = torch.cat(psp_outs, dim=1)
+ output = self.bottleneck(psp_outs)
+
+ return output
+
+ def forward(self, inputs):
+ """Forward function."""
+
+ inputs = self._transform_inputs(inputs)
+
+ # build laterals
+ laterals = [
+ lateral_conv(inputs[i])
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+
+ laterals.append(self.psp_forward(inputs))
+
+ # build top-down path
+ used_backbone_levels = len(laterals)
+ for i in range(used_backbone_levels - 1, 0, -1):
+ prev_shape = laterals[i - 1].shape[2:]
+ laterals[i - 1] += resize(
+ laterals[i],
+ size=prev_shape,
+ mode='bilinear',
+ align_corners=self.align_corners)
+
+ # build outputs
+ fpn_outs = [
+ self.fpn_convs[i](laterals[i])
+ for i in range(used_backbone_levels - 1)
+ ]
+ # append psp feature
+ fpn_outs.append(laterals[-1])
+
+ for i in range(used_backbone_levels - 1, 0, -1):
+ fpn_outs[i] = resize(
+ fpn_outs[i],
+ size=fpn_outs[0].shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ fpn_outs = torch.cat(fpn_outs, dim=1)
+ output = self.fpn_bottleneck(fpn_outs)
+ output = self.cls_seg(output)
+ return output
diff --git a/src/custom_mmpkg/custom_mmseg/models/losses/__init__.py b/src/custom_mmpkg/custom_mmseg/models/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..beca72045694273d63465bac2f27dbc6672271db
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/losses/__init__.py
@@ -0,0 +1,12 @@
+from .accuracy import Accuracy, accuracy
+from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
+ cross_entropy, mask_cross_entropy)
+from .dice_loss import DiceLoss
+from .lovasz_loss import LovaszLoss
+from .utils import reduce_loss, weight_reduce_loss, weighted_loss
+
+__all__ = [
+ 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
+ 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
+ 'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss'
+]
diff --git a/src/custom_mmpkg/custom_mmseg/models/losses/accuracy.py b/src/custom_mmpkg/custom_mmseg/models/losses/accuracy.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0fd2e7e74a0f721c4a814c09d6e453e5956bb38
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/losses/accuracy.py
@@ -0,0 +1,78 @@
+import torch.nn as nn
+
+
+def accuracy(pred, target, topk=1, thresh=None):
+ """Calculate accuracy according to the prediction and target.
+
+ Args:
+ pred (torch.Tensor): The model prediction, shape (N, num_class, ...)
+ target (torch.Tensor): The target of each prediction, shape (N, , ...)
+ topk (int | tuple[int], optional): If the predictions in ``topk``
+ matches the target, the predictions will be regarded as
+ correct ones. Defaults to 1.
+ thresh (float, optional): If not None, predictions with scores under
+ this threshold are considered incorrect. Default to None.
+
+ Returns:
+ float | tuple[float]: If the input ``topk`` is a single integer,
+ the function will return a single float as accuracy. If
+ ``topk`` is a tuple containing multiple integers, the
+ function will return a tuple containing accuracies of
+ each ``topk`` number.
+ """
+ assert isinstance(topk, (int, tuple))
+ if isinstance(topk, int):
+ topk = (topk, )
+ return_single = True
+ else:
+ return_single = False
+
+ maxk = max(topk)
+ if pred.size(0) == 0:
+ accu = [pred.new_tensor(0.) for i in range(len(topk))]
+ return accu[0] if return_single else accu
+ assert pred.ndim == target.ndim + 1
+ assert pred.size(0) == target.size(0)
+ assert maxk <= pred.size(1), \
+ f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
+ pred_value, pred_label = pred.topk(maxk, dim=1)
+ # transpose to shape (maxk, N, ...)
+ pred_label = pred_label.transpose(0, 1)
+ correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label))
+ if thresh is not None:
+ # Only prediction values larger than thresh are counted as correct
+ correct = correct & (pred_value > thresh).t()
+ res = []
+ for k in topk:
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+ res.append(correct_k.mul_(100.0 / target.numel()))
+ return res[0] if return_single else res
+
+
+class Accuracy(nn.Module):
+ """Accuracy calculation module."""
+
+ def __init__(self, topk=(1, ), thresh=None):
+ """Module to calculate the accuracy.
+
+ Args:
+ topk (tuple, optional): The criterion used to calculate the
+ accuracy. Defaults to (1,).
+ thresh (float, optional): If not None, predictions with scores
+ under this threshold are considered incorrect. Default to None.
+ """
+ super().__init__()
+ self.topk = topk
+ self.thresh = thresh
+
+ def forward(self, pred, target):
+ """Forward function to calculate accuracy.
+
+ Args:
+ pred (torch.Tensor): Prediction of models.
+ target (torch.Tensor): Target for each prediction.
+
+ Returns:
+ tuple[float]: The accuracies under different topk criterions.
+ """
+ return accuracy(pred, target, self.topk, self.thresh)
diff --git a/src/custom_mmpkg/custom_mmseg/models/losses/cross_entropy_loss.py b/src/custom_mmpkg/custom_mmseg/models/losses/cross_entropy_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..42c0790c98616bb69621deed55547fc04c7392ef
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/losses/cross_entropy_loss.py
@@ -0,0 +1,198 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import get_class_weight, weight_reduce_loss
+
+
+def cross_entropy(pred,
+ label,
+ weight=None,
+ class_weight=None,
+ reduction='mean',
+ avg_factor=None,
+ ignore_index=-100):
+ """The wrapper function for :func:`F.cross_entropy`"""
+ # class_weight is a manual rescaling weight given to each class.
+ # If given, has to be a Tensor of size C element-wise losses
+ loss = F.cross_entropy(
+ pred,
+ label,
+ weight=class_weight,
+ reduction='none',
+ ignore_index=ignore_index)
+
+ # apply weights and do the reduction
+ if weight is not None:
+ weight = weight.float()
+ loss = weight_reduce_loss(
+ loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
+
+ return loss
+
+
+def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
+ """Expand onehot labels to match the size of prediction."""
+ bin_labels = labels.new_zeros(target_shape)
+ valid_mask = (labels >= 0) & (labels != ignore_index)
+ inds = torch.nonzero(valid_mask, as_tuple=True)
+
+ if inds[0].numel() > 0:
+ if labels.dim() == 3:
+ bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
+ else:
+ bin_labels[inds[0], labels[valid_mask]] = 1
+
+ valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
+ if label_weights is None:
+ bin_label_weights = valid_mask
+ else:
+ bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
+ bin_label_weights *= valid_mask
+
+ return bin_labels, bin_label_weights
+
+
+def binary_cross_entropy(pred,
+ label,
+ weight=None,
+ reduction='mean',
+ avg_factor=None,
+ class_weight=None,
+ ignore_index=255):
+ """Calculate the binary CrossEntropy loss.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, 1).
+ label (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): Sample-wise loss weight.
+ reduction (str, optional): The method used to reduce the loss.
+ Options are "none", "mean" and "sum".
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ class_weight (list[float], optional): The weight for each class.
+ ignore_index (int | None): The label index to be ignored. Default: 255
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ if pred.dim() != label.dim():
+ assert (pred.dim() == 2 and label.dim() == 1) or (
+ pred.dim() == 4 and label.dim() == 3), \
+ 'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
+ 'H, W], label shape [N, H, W] are supported'
+ label, weight = _expand_onehot_labels(label, weight, pred.shape,
+ ignore_index)
+
+ # weighted element-wise losses
+ if weight is not None:
+ weight = weight.float()
+ loss = F.binary_cross_entropy_with_logits(
+ pred, label.float(), pos_weight=class_weight, reduction='none')
+ # do the reduction for the weighted loss
+ loss = weight_reduce_loss(
+ loss, weight, reduction=reduction, avg_factor=avg_factor)
+
+ return loss
+
+
+def mask_cross_entropy(pred,
+ target,
+ label,
+ reduction='mean',
+ avg_factor=None,
+ class_weight=None,
+ ignore_index=None):
+ """Calculate the CrossEntropy loss for masks.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, C), C is the number
+ of classes.
+ target (torch.Tensor): The learning label of the prediction.
+ label (torch.Tensor): ``label`` indicates the class label of the mask'
+ corresponding object. This will be used to select the mask in the
+ of the class which the object belongs to when the mask prediction
+ if not class-agnostic.
+ reduction (str, optional): The method used to reduce the loss.
+ Options are "none", "mean" and "sum".
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ class_weight (list[float], optional): The weight for each class.
+ ignore_index (None): Placeholder, to be consistent with other loss.
+ Default: None.
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ assert ignore_index is None, 'BCE loss does not support ignore_index'
+ # TODO: handle these two reserved arguments
+ assert reduction == 'mean' and avg_factor is None
+ num_rois = pred.size()[0]
+ inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
+ pred_slice = pred[inds, label].squeeze(1)
+ return F.binary_cross_entropy_with_logits(
+ pred_slice, target, weight=class_weight, reduction='mean')[None]
+
+
+@LOSSES.register_module()
+class CrossEntropyLoss(nn.Module):
+ """CrossEntropyLoss.
+
+ Args:
+ use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+ of softmax. Defaults to False.
+ use_mask (bool, optional): Whether to use mask cross entropy loss.
+ Defaults to False.
+ reduction (str, optional): . Defaults to 'mean'.
+ Options are "none", "mean" and "sum".
+ class_weight (list[float] | str, optional): Weight of each class. If in
+ str format, read them from a file. Defaults to None.
+ loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+ """
+
+ def __init__(self,
+ use_sigmoid=False,
+ use_mask=False,
+ reduction='mean',
+ class_weight=None,
+ loss_weight=1.0):
+ super(CrossEntropyLoss, self).__init__()
+ assert (use_sigmoid is False) or (use_mask is False)
+ self.use_sigmoid = use_sigmoid
+ self.use_mask = use_mask
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.class_weight = get_class_weight(class_weight)
+
+ if self.use_sigmoid:
+ self.cls_criterion = binary_cross_entropy
+ elif self.use_mask:
+ self.cls_criterion = mask_cross_entropy
+ else:
+ self.cls_criterion = cross_entropy
+
+ def forward(self,
+ cls_score,
+ label,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ """Forward function."""
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if self.class_weight is not None:
+ class_weight = cls_score.new_tensor(self.class_weight)
+ else:
+ class_weight = None
+ loss_cls = self.loss_weight * self.cls_criterion(
+ cls_score,
+ label,
+ weight,
+ class_weight=class_weight,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss_cls
diff --git a/src/custom_mmpkg/custom_mmseg/models/losses/dice_loss.py b/src/custom_mmpkg/custom_mmseg/models/losses/dice_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..27a77b962d7d8b3079c7d6cd9db52280c6fb4970
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/losses/dice_loss.py
@@ -0,0 +1,119 @@
+"""Modified from https://github.com/LikeLy-Journey/SegmenTron/blob/master/
+segmentron/solver/loss.py (Apache-2.0 License)"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import get_class_weight, weighted_loss
+
+
+@weighted_loss
+def dice_loss(pred,
+ target,
+ valid_mask,
+ smooth=1,
+ exponent=2,
+ class_weight=None,
+ ignore_index=255):
+ assert pred.shape[0] == target.shape[0]
+ total_loss = 0
+ num_classes = pred.shape[1]
+ for i in range(num_classes):
+ if i != ignore_index:
+ dice_loss = binary_dice_loss(
+ pred[:, i],
+ target[..., i],
+ valid_mask=valid_mask,
+ smooth=smooth,
+ exponent=exponent)
+ if class_weight is not None:
+ dice_loss *= class_weight[i]
+ total_loss += dice_loss
+ return total_loss / num_classes
+
+
+@weighted_loss
+def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwards):
+ assert pred.shape[0] == target.shape[0]
+ pred = pred.reshape(pred.shape[0], -1)
+ target = target.reshape(target.shape[0], -1)
+ valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)
+
+ num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth
+ den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth
+
+ return 1 - num / den
+
+
+@LOSSES.register_module()
+class DiceLoss(nn.Module):
+ """DiceLoss.
+
+ This loss is proposed in `V-Net: Fully Convolutional Neural Networks for
+ Volumetric Medical Image Segmentation `_.
+
+ Args:
+ loss_type (str, optional): Binary or multi-class loss.
+ Default: 'multi_class'. Options are "binary" and "multi_class".
+ smooth (float): A float number to smooth loss, and avoid NaN error.
+ Default: 1
+ exponent (float): An float number to calculate denominator
+ value: \\sum{x^exponent} + \\sum{y^exponent}. Default: 2.
+ reduction (str, optional): The method used to reduce the loss. Options
+ are "none", "mean" and "sum". This parameter only works when
+ per_image is True. Default: 'mean'.
+ class_weight (list[float] | str, optional): Weight of each class. If in
+ str format, read them from a file. Defaults to None.
+ loss_weight (float, optional): Weight of the loss. Default to 1.0.
+ ignore_index (int | None): The label index to be ignored. Default: 255.
+ """
+
+ def __init__(self,
+ smooth=1,
+ exponent=2,
+ reduction='mean',
+ class_weight=None,
+ loss_weight=1.0,
+ ignore_index=255,
+ **kwards):
+ super(DiceLoss, self).__init__()
+ self.smooth = smooth
+ self.exponent = exponent
+ self.reduction = reduction
+ self.class_weight = get_class_weight(class_weight)
+ self.loss_weight = loss_weight
+ self.ignore_index = ignore_index
+
+ def forward(self,
+ pred,
+ target,
+ avg_factor=None,
+ reduction_override=None,
+ **kwards):
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if self.class_weight is not None:
+ class_weight = pred.new_tensor(self.class_weight)
+ else:
+ class_weight = None
+
+ pred = F.softmax(pred, dim=1)
+ num_classes = pred.shape[1]
+ one_hot_target = F.one_hot(
+ torch.clamp(target.long(), 0, num_classes - 1),
+ num_classes=num_classes)
+ valid_mask = (target != self.ignore_index).long()
+
+ loss = self.loss_weight * dice_loss(
+ pred,
+ one_hot_target,
+ valid_mask=valid_mask,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ smooth=self.smooth,
+ exponent=self.exponent,
+ class_weight=class_weight,
+ ignore_index=self.ignore_index)
+ return loss
diff --git a/src/custom_mmpkg/custom_mmseg/models/losses/lovasz_loss.py b/src/custom_mmpkg/custom_mmseg/models/losses/lovasz_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..e381378522673d41a4ce2b9b9d6d70b9b3102bb0
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/losses/lovasz_loss.py
@@ -0,0 +1,303 @@
+"""Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor
+ch/lovasz_losses.py Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim
+Berman 2018 ESAT-PSI KU Leuven (MIT License)"""
+
+import custom_mmpkg.custom_mmcv as mmcv
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import get_class_weight, weight_reduce_loss
+
+
+def lovasz_grad(gt_sorted):
+ """Computes gradient of the Lovasz extension w.r.t sorted errors.
+
+ See Alg. 1 in paper.
+ """
+ p = len(gt_sorted)
+ gts = gt_sorted.sum()
+ intersection = gts - gt_sorted.float().cumsum(0)
+ union = gts + (1 - gt_sorted).float().cumsum(0)
+ jaccard = 1. - intersection / union
+ if p > 1: # cover 1-pixel case
+ jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
+ return jaccard
+
+
+def flatten_binary_logits(logits, labels, ignore_index=None):
+ """Flattens predictions in the batch (binary case) Remove labels equal to
+ 'ignore_index'."""
+ logits = logits.view(-1)
+ labels = labels.view(-1)
+ if ignore_index is None:
+ return logits, labels
+ valid = (labels != ignore_index)
+ vlogits = logits[valid]
+ vlabels = labels[valid]
+ return vlogits, vlabels
+
+
+def flatten_probs(probs, labels, ignore_index=None):
+ """Flattens predictions in the batch."""
+ if probs.dim() == 3:
+ # assumes output of a sigmoid layer
+ B, H, W = probs.size()
+ probs = probs.view(B, 1, H, W)
+ B, C, H, W = probs.size()
+ probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, C) # B*H*W, C=P,C
+ labels = labels.view(-1)
+ if ignore_index is None:
+ return probs, labels
+ valid = (labels != ignore_index)
+ vprobs = probs[valid.nonzero().squeeze()]
+ vlabels = labels[valid]
+ return vprobs, vlabels
+
+
+def lovasz_hinge_flat(logits, labels):
+ """Binary Lovasz hinge loss.
+
+ Args:
+ logits (torch.Tensor): [P], logits at each prediction
+ (between -infty and +infty).
+ labels (torch.Tensor): [P], binary ground truth labels (0 or 1).
+
+ Returns:
+ torch.Tensor: The calculated loss.
+ """
+ if len(labels) == 0:
+ # only void pixels, the gradients should be 0
+ return logits.sum() * 0.
+ signs = 2. * labels.float() - 1.
+ errors = (1. - logits * signs)
+ errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
+ perm = perm.data
+ gt_sorted = labels[perm]
+ grad = lovasz_grad(gt_sorted)
+ loss = torch.dot(F.relu(errors_sorted), grad)
+ return loss
+
+
+def lovasz_hinge(logits,
+ labels,
+ classes='present',
+ per_image=False,
+ class_weight=None,
+ reduction='mean',
+ avg_factor=None,
+ ignore_index=255):
+ """Binary Lovasz hinge loss.
+
+ Args:
+ logits (torch.Tensor): [B, H, W], logits at each pixel
+ (between -infty and +infty).
+ labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1).
+ classes (str | list[int], optional): Placeholder, to be consistent with
+ other loss. Default: None.
+ per_image (bool, optional): If per_image is True, compute the loss per
+ image instead of per batch. Default: False.
+ class_weight (list[float], optional): Placeholder, to be consistent
+ with other loss. Default: None.
+ reduction (str, optional): The method used to reduce the loss. Options
+ are "none", "mean" and "sum". This parameter only works when
+ per_image is True. Default: 'mean'.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. This parameter only works when per_image is True.
+ Default: None.
+ ignore_index (int | None): The label index to be ignored. Default: 255.
+
+ Returns:
+ torch.Tensor: The calculated loss.
+ """
+ if per_image:
+ loss = [
+ lovasz_hinge_flat(*flatten_binary_logits(
+ logit.unsqueeze(0), label.unsqueeze(0), ignore_index))
+ for logit, label in zip(logits, labels)
+ ]
+ loss = weight_reduce_loss(
+ torch.stack(loss), None, reduction, avg_factor)
+ else:
+ loss = lovasz_hinge_flat(
+ *flatten_binary_logits(logits, labels, ignore_index))
+ return loss
+
+
+def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None):
+ """Multi-class Lovasz-Softmax loss.
+
+ Args:
+ probs (torch.Tensor): [P, C], class probabilities at each prediction
+ (between 0 and 1).
+ labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1).
+ classes (str | list[int], optional): Classes chosen to calculate loss.
+ 'all' for all classes, 'present' for classes present in labels, or
+ a list of classes to average. Default: 'present'.
+ class_weight (list[float], optional): The weight for each class.
+ Default: None.
+
+ Returns:
+ torch.Tensor: The calculated loss.
+ """
+ if probs.numel() == 0:
+ # only void pixels, the gradients should be 0
+ return probs * 0.
+ C = probs.size(1)
+ losses = []
+ class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
+ for c in class_to_sum:
+ fg = (labels == c).float() # foreground for class c
+ if (classes == 'present' and fg.sum() == 0):
+ continue
+ if C == 1:
+ if len(classes) > 1:
+ raise ValueError('Sigmoid output possible only with 1 class')
+ class_pred = probs[:, 0]
+ else:
+ class_pred = probs[:, c]
+ errors = (fg - class_pred).abs()
+ errors_sorted, perm = torch.sort(errors, 0, descending=True)
+ perm = perm.data
+ fg_sorted = fg[perm]
+ loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted))
+ if class_weight is not None:
+ loss *= class_weight[c]
+ losses.append(loss)
+ return torch.stack(losses).mean()
+
+
+def lovasz_softmax(probs,
+ labels,
+ classes='present',
+ per_image=False,
+ class_weight=None,
+ reduction='mean',
+ avg_factor=None,
+ ignore_index=255):
+ """Multi-class Lovasz-Softmax loss.
+
+ Args:
+ probs (torch.Tensor): [B, C, H, W], class probabilities at each
+ prediction (between 0 and 1).
+ labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and
+ C - 1).
+ classes (str | list[int], optional): Classes chosen to calculate loss.
+ 'all' for all classes, 'present' for classes present in labels, or
+ a list of classes to average. Default: 'present'.
+ per_image (bool, optional): If per_image is True, compute the loss per
+ image instead of per batch. Default: False.
+ class_weight (list[float], optional): The weight for each class.
+ Default: None.
+ reduction (str, optional): The method used to reduce the loss. Options
+ are "none", "mean" and "sum". This parameter only works when
+ per_image is True. Default: 'mean'.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. This parameter only works when per_image is True.
+ Default: None.
+ ignore_index (int | None): The label index to be ignored. Default: 255.
+
+ Returns:
+ torch.Tensor: The calculated loss.
+ """
+
+ if per_image:
+ loss = [
+ lovasz_softmax_flat(
+ *flatten_probs(
+ prob.unsqueeze(0), label.unsqueeze(0), ignore_index),
+ classes=classes,
+ class_weight=class_weight)
+ for prob, label in zip(probs, labels)
+ ]
+ loss = weight_reduce_loss(
+ torch.stack(loss), None, reduction, avg_factor)
+ else:
+ loss = lovasz_softmax_flat(
+ *flatten_probs(probs, labels, ignore_index),
+ classes=classes,
+ class_weight=class_weight)
+ return loss
+
+
+@LOSSES.register_module()
+class LovaszLoss(nn.Module):
+ """LovaszLoss.
+
+ This loss is proposed in `The Lovasz-Softmax loss: A tractable surrogate
+ for the optimization of the intersection-over-union measure in neural
+ networks `_.
+
+ Args:
+ loss_type (str, optional): Binary or multi-class loss.
+ Default: 'multi_class'. Options are "binary" and "multi_class".
+ classes (str | list[int], optional): Classes chosen to calculate loss.
+ 'all' for all classes, 'present' for classes present in labels, or
+ a list of classes to average. Default: 'present'.
+ per_image (bool, optional): If per_image is True, compute the loss per
+ image instead of per batch. Default: False.
+ reduction (str, optional): The method used to reduce the loss. Options
+ are "none", "mean" and "sum". This parameter only works when
+ per_image is True. Default: 'mean'.
+ class_weight (list[float] | str, optional): Weight of each class. If in
+ str format, read them from a file. Defaults to None.
+ loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+ """
+
+ def __init__(self,
+ loss_type='multi_class',
+ classes='present',
+ per_image=False,
+ reduction='mean',
+ class_weight=None,
+ loss_weight=1.0):
+ super(LovaszLoss, self).__init__()
+ assert loss_type in ('binary', 'multi_class'), "loss_type should be \
+ 'binary' or 'multi_class'."
+
+ if loss_type == 'binary':
+ self.cls_criterion = lovasz_hinge
+ else:
+ self.cls_criterion = lovasz_softmax
+ assert classes in ('all', 'present') or mmcv.is_list_of(classes, int)
+ if not per_image:
+ assert reduction == 'none', "reduction should be 'none' when \
+ per_image is False."
+
+ self.classes = classes
+ self.per_image = per_image
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.class_weight = get_class_weight(class_weight)
+
+ def forward(self,
+ cls_score,
+ label,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ """Forward function."""
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if self.class_weight is not None:
+ class_weight = cls_score.new_tensor(self.class_weight)
+ else:
+ class_weight = None
+
+ # if multi-class loss, transform logits to probs
+ if self.cls_criterion == lovasz_softmax:
+ cls_score = F.softmax(cls_score, dim=1)
+
+ loss_cls = self.loss_weight * self.cls_criterion(
+ cls_score,
+ label,
+ self.classes,
+ self.per_image,
+ class_weight=class_weight,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss_cls
diff --git a/src/custom_mmpkg/custom_mmseg/models/losses/utils.py b/src/custom_mmpkg/custom_mmseg/models/losses/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdfbd436d2305f82af065a853e789d6fa37614cd
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/losses/utils.py
@@ -0,0 +1,121 @@
+import functools
+
+import custom_mmpkg.custom_mmcv as mmcv
+import numpy as np
+import torch.nn.functional as F
+
+
+def get_class_weight(class_weight):
+ """Get class weight for loss function.
+
+ Args:
+ class_weight (list[float] | str | None): If class_weight is a str,
+ take it as a file name and read from it.
+ """
+ if isinstance(class_weight, str):
+ # take it as a file path
+ if class_weight.endswith('.npy'):
+ class_weight = np.load(class_weight)
+ else:
+ # pkl, json or yaml
+ class_weight = mmcv.load(class_weight)
+
+ return class_weight
+
+
+def reduce_loss(loss, reduction):
+ """Reduce loss as specified.
+
+ Args:
+ loss (Tensor): Elementwise loss tensor.
+ reduction (str): Options are "none", "mean" and "sum".
+
+ Return:
+ Tensor: Reduced loss tensor.
+ """
+ reduction_enum = F._Reduction.get_enum(reduction)
+ # none: 0, elementwise_mean:1, sum: 2
+ if reduction_enum == 0:
+ return loss
+ elif reduction_enum == 1:
+ return loss.mean()
+ elif reduction_enum == 2:
+ return loss.sum()
+
+
+def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
+ """Apply element-wise weight and reduce loss.
+
+ Args:
+ loss (Tensor): Element-wise loss.
+ weight (Tensor): Element-wise weights.
+ reduction (str): Same as built-in losses of PyTorch.
+ avg_factor (float): Avarage factor when computing the mean of losses.
+
+ Returns:
+ Tensor: Processed loss values.
+ """
+ # if weight is specified, apply element-wise weight
+ if weight is not None:
+ assert weight.dim() == loss.dim()
+ if weight.dim() > 1:
+ assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
+ loss = loss * weight
+
+ # if avg_factor is not specified, just reduce the loss
+ if avg_factor is None:
+ loss = reduce_loss(loss, reduction)
+ else:
+ # if reduction is mean, then average the loss by avg_factor
+ if reduction == 'mean':
+ loss = loss.sum() / avg_factor
+ # if reduction is 'none', then do nothing, otherwise raise an error
+ elif reduction != 'none':
+ raise ValueError('avg_factor can not be used with reduction="sum"')
+ return loss
+
+
+def weighted_loss(loss_func):
+ """Create a weighted version of a given loss function.
+
+ To use this decorator, the loss function must have the signature like
+ `loss_func(pred, target, **kwargs)`. The function only needs to compute
+ element-wise loss without any reduction. This decorator will add weight
+ and reduction arguments to the function. The decorated function will have
+ the signature like `loss_func(pred, target, weight=None, reduction='mean',
+ avg_factor=None, **kwargs)`.
+
+ :Example:
+
+ >>> import torch
+ >>> @weighted_loss
+ >>> def l1_loss(pred, target):
+ >>> return (pred - target).abs()
+
+ >>> pred = torch.Tensor([0, 2, 3])
+ >>> target = torch.Tensor([1, 1, 1])
+ >>> weight = torch.Tensor([1, 0, 1])
+
+ >>> l1_loss(pred, target)
+ tensor(1.3333)
+ >>> l1_loss(pred, target, weight)
+ tensor(1.)
+ >>> l1_loss(pred, target, reduction='none')
+ tensor([1., 1., 2.])
+ >>> l1_loss(pred, target, weight, avg_factor=2)
+ tensor(1.5000)
+ """
+
+ @functools.wraps(loss_func)
+ def wrapper(pred,
+ target,
+ weight=None,
+ reduction='mean',
+ avg_factor=None,
+ **kwargs):
+ # get element-wise loss
+ loss = loss_func(pred, target, **kwargs)
+ loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+ return loss
+
+ return wrapper
diff --git a/src/custom_mmpkg/custom_mmseg/models/necks/__init__.py b/src/custom_mmpkg/custom_mmseg/models/necks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b9d3d5b3fe80247642d962edd6fb787537d01d6
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/necks/__init__.py
@@ -0,0 +1,4 @@
+from .fpn import FPN
+from .multilevel_neck import MultiLevelNeck
+
+__all__ = ['FPN', 'MultiLevelNeck']
diff --git a/src/custom_mmpkg/custom_mmseg/models/necks/fpn.py b/src/custom_mmpkg/custom_mmseg/models/necks/fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c32cc5e44f92e7779ba1ba913c2482107e5900d
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/necks/fpn.py
@@ -0,0 +1,212 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from custom_mmpkg.custom_mmcv.cnn import ConvModule, xavier_init
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class FPN(nn.Module):
+ """Feature Pyramid Network.
+
+ This is an implementation of - Feature Pyramid Networks for Object
+ Detection (https://arxiv.org/abs/1612.03144)
+
+ Args:
+ in_channels (List[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale)
+ num_outs (int): Number of output scales.
+ start_level (int): Index of the start input backbone level used to
+ build the feature pyramid. Default: 0.
+ end_level (int): Index of the end input backbone level (exclusive) to
+ build the feature pyramid. Default: -1, which means the last level.
+ add_extra_convs (bool | str): If bool, it decides whether to add conv
+ layers on top of the original feature maps. Default to False.
+ If True, its actual mode is specified by `extra_convs_on_inputs`.
+ If str, it specifies the source feature map of the extra convs.
+ Only the following options are allowed
+
+ - 'on_input': Last feat map of neck inputs (i.e. backbone feature).
+ - 'on_lateral': Last feature map after lateral convs.
+ - 'on_output': The last output feature map after fpn convs.
+ extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs
+ on the original feature from the backbone. If True,
+ it is equivalent to `add_extra_convs='on_input'`. If False, it is
+ equivalent to set `add_extra_convs='on_output'`. Default to True.
+ relu_before_extra_convs (bool): Whether to apply relu before the extra
+ conv. Default: False.
+ no_norm_on_lateral (bool): Whether to apply norm on lateral.
+ Default: False.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ act_cfg (str): Config dict for activation layer in ConvModule.
+ Default: None.
+ upsample_cfg (dict): Config dict for interpolate layer.
+ Default: `dict(mode='nearest')`
+
+ Example:
+ >>> import torch
+ >>> in_channels = [2, 3, 5, 7]
+ >>> scales = [340, 170, 84, 43]
+ >>> inputs = [torch.rand(1, c, s, s)
+ ... for c, s in zip(in_channels, scales)]
+ >>> self = FPN(in_channels, 11, len(in_channels)).eval()
+ >>> outputs = self.forward(inputs)
+ >>> for i in range(len(outputs)):
+ ... print(f'outputs[{i}].shape = {outputs[i].shape}')
+ outputs[0].shape = torch.Size([1, 11, 340, 340])
+ outputs[1].shape = torch.Size([1, 11, 170, 170])
+ outputs[2].shape = torch.Size([1, 11, 84, 84])
+ outputs[3].shape = torch.Size([1, 11, 43, 43])
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_outs,
+ start_level=0,
+ end_level=-1,
+ add_extra_convs=False,
+ extra_convs_on_inputs=False,
+ relu_before_extra_convs=False,
+ no_norm_on_lateral=False,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=None,
+ upsample_cfg=dict(mode='nearest')):
+ super(FPN, self).__init__()
+ assert isinstance(in_channels, list)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_ins = len(in_channels)
+ self.num_outs = num_outs
+ self.relu_before_extra_convs = relu_before_extra_convs
+ self.no_norm_on_lateral = no_norm_on_lateral
+ self.fp16_enabled = False
+ self.upsample_cfg = upsample_cfg.copy()
+
+ if end_level == -1:
+ self.backbone_end_level = self.num_ins
+ assert num_outs >= self.num_ins - start_level
+ else:
+ # if end_level < inputs, no extra level is allowed
+ self.backbone_end_level = end_level
+ assert end_level <= len(in_channels)
+ assert num_outs == end_level - start_level
+ self.start_level = start_level
+ self.end_level = end_level
+ self.add_extra_convs = add_extra_convs
+ assert isinstance(add_extra_convs, (str, bool))
+ if isinstance(add_extra_convs, str):
+ # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
+ assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
+ elif add_extra_convs: # True
+ if extra_convs_on_inputs:
+ # For compatibility with previous release
+ # TODO: deprecate `extra_convs_on_inputs`
+ self.add_extra_convs = 'on_input'
+ else:
+ self.add_extra_convs = 'on_output'
+
+ self.lateral_convs = nn.ModuleList()
+ self.fpn_convs = nn.ModuleList()
+
+ for i in range(self.start_level, self.backbone_end_level):
+ l_conv = ConvModule(
+ in_channels[i],
+ out_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
+ act_cfg=act_cfg,
+ inplace=False)
+ fpn_conv = ConvModule(
+ out_channels,
+ out_channels,
+ 3,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+
+ self.lateral_convs.append(l_conv)
+ self.fpn_convs.append(fpn_conv)
+
+ # add extra conv layers (e.g., RetinaNet)
+ extra_levels = num_outs - self.backbone_end_level + self.start_level
+ if self.add_extra_convs and extra_levels >= 1:
+ for i in range(extra_levels):
+ if i == 0 and self.add_extra_convs == 'on_input':
+ in_channels = self.in_channels[self.backbone_end_level - 1]
+ else:
+ in_channels = out_channels
+ extra_fpn_conv = ConvModule(
+ in_channels,
+ out_channels,
+ 3,
+ stride=2,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+ self.fpn_convs.append(extra_fpn_conv)
+
+ # default init_weights for conv(msra) and norm in ConvModule
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ xavier_init(m, distribution='uniform')
+
+ def forward(self, inputs):
+ assert len(inputs) == len(self.in_channels)
+
+ # build laterals
+ laterals = [
+ lateral_conv(inputs[i + self.start_level])
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+
+ # build top-down path
+ used_backbone_levels = len(laterals)
+ for i in range(used_backbone_levels - 1, 0, -1):
+ # In some cases, fixing `scale factor` (e.g. 2) is preferred, but
+ # it cannot co-exist with `size` in `F.interpolate`.
+ if 'scale_factor' in self.upsample_cfg:
+ laterals[i - 1] += F.interpolate(laterals[i],
+ **self.upsample_cfg)
+ else:
+ prev_shape = laterals[i - 1].shape[2:]
+ laterals[i - 1] += F.interpolate(
+ laterals[i], size=prev_shape, **self.upsample_cfg)
+
+ # build outputs
+ # part 1: from original levels
+ outs = [
+ self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
+ ]
+ # part 2: add extra levels
+ if self.num_outs > len(outs):
+ # use max pool to get more levels on top of outputs
+ # (e.g., Faster R-CNN, Mask R-CNN)
+ if not self.add_extra_convs:
+ for i in range(self.num_outs - used_backbone_levels):
+ outs.append(F.max_pool2d(outs[-1], 1, stride=2))
+ # add conv layers on top of original feature maps (RetinaNet)
+ else:
+ if self.add_extra_convs == 'on_input':
+ extra_source = inputs[self.backbone_end_level - 1]
+ elif self.add_extra_convs == 'on_lateral':
+ extra_source = laterals[-1]
+ elif self.add_extra_convs == 'on_output':
+ extra_source = outs[-1]
+ else:
+ raise NotImplementedError
+ outs.append(self.fpn_convs[used_backbone_levels](extra_source))
+ for i in range(used_backbone_levels + 1, self.num_outs):
+ if self.relu_before_extra_convs:
+ outs.append(self.fpn_convs[i](F.relu(outs[-1])))
+ else:
+ outs.append(self.fpn_convs[i](outs[-1]))
+ return tuple(outs)
diff --git a/src/custom_mmpkg/custom_mmseg/models/necks/multilevel_neck.py b/src/custom_mmpkg/custom_mmseg/models/necks/multilevel_neck.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce5e8563f7d7b27944fbad2c247f789967433bba
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/necks/multilevel_neck.py
@@ -0,0 +1,70 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from custom_mmpkg.custom_mmcv.cnn import ConvModule
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class MultiLevelNeck(nn.Module):
+ """MultiLevelNeck.
+
+ A neck structure connect vit backbone and decoder_heads.
+ Args:
+ in_channels (List[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale).
+ scales (List[int]): Scale factors for each input feature map.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ act_cfg (dict): Config dict for activation layer in ConvModule.
+ Default: None.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ scales=[0.5, 1, 2, 4],
+ norm_cfg=None,
+ act_cfg=None):
+ super(MultiLevelNeck, self).__init__()
+ assert isinstance(in_channels, list)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.scales = scales
+ self.num_outs = len(scales)
+ self.lateral_convs = nn.ModuleList()
+ self.convs = nn.ModuleList()
+ for in_channel in in_channels:
+ self.lateral_convs.append(
+ ConvModule(
+ in_channel,
+ out_channels,
+ kernel_size=1,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ for _ in range(self.num_outs):
+ self.convs.append(
+ ConvModule(
+ out_channels,
+ out_channels,
+ kernel_size=3,
+ padding=1,
+ stride=1,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+
+ def forward(self, inputs):
+ assert len(inputs) == len(self.in_channels)
+ print(inputs[0].shape)
+ inputs = [
+ lateral_conv(inputs[i])
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+ # for len(inputs) not equal to self.num_outs
+ if len(inputs) == 1:
+ inputs = [inputs[0] for _ in range(self.num_outs)]
+ outs = []
+ for i in range(self.num_outs):
+ x_resize = F.interpolate(
+ inputs[i], scale_factor=self.scales[i], mode='bilinear')
+ outs.append(self.convs[i](x_resize))
+ return tuple(outs)
diff --git a/src/custom_mmpkg/custom_mmseg/models/segmentors/__init__.py b/src/custom_mmpkg/custom_mmseg/models/segmentors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dca2f09405330743c476e190896bee39c45498ea
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/segmentors/__init__.py
@@ -0,0 +1,5 @@
+from .base import BaseSegmentor
+from .cascade_encoder_decoder import CascadeEncoderDecoder
+from .encoder_decoder import EncoderDecoder
+
+__all__ = ['BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder']
diff --git a/src/custom_mmpkg/custom_mmseg/models/segmentors/base.py b/src/custom_mmpkg/custom_mmseg/models/segmentors/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fd073f4a9d7713f02b107b7ad541384c3d27d6b
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/segmentors/base.py
@@ -0,0 +1,273 @@
+import logging
+import warnings
+from abc import ABCMeta, abstractmethod
+from collections import OrderedDict
+
+import custom_mmpkg.custom_mmcv as mmcv
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from custom_mmpkg.custom_mmcv.runner import auto_fp16
+
+
+class BaseSegmentor(nn.Module):
+ """Base class for segmentors."""
+
+ __metaclass__ = ABCMeta
+
+ def __init__(self):
+ super(BaseSegmentor, self).__init__()
+ self.fp16_enabled = False
+
+ @property
+ def with_neck(self):
+ """bool: whether the segmentor has neck"""
+ return hasattr(self, 'neck') and self.neck is not None
+
+ @property
+ def with_auxiliary_head(self):
+ """bool: whether the segmentor has auxiliary head"""
+ return hasattr(self,
+ 'auxiliary_head') and self.auxiliary_head is not None
+
+ @property
+ def with_decode_head(self):
+ """bool: whether the segmentor has decode head"""
+ return hasattr(self, 'decode_head') and self.decode_head is not None
+
+ @abstractmethod
+ def extract_feat(self, imgs):
+ """Placeholder for extract features from images."""
+ pass
+
+ @abstractmethod
+ def encode_decode(self, img, img_metas):
+ """Placeholder for encode images with backbone and decode into a
+ semantic segmentation map of the same size as input."""
+ pass
+
+ @abstractmethod
+ def forward_train(self, imgs, img_metas, **kwargs):
+ """Placeholder for Forward function for training."""
+ pass
+
+ @abstractmethod
+ def simple_test(self, img, img_meta, **kwargs):
+ """Placeholder for single image test."""
+ pass
+
+ @abstractmethod
+ def aug_test(self, imgs, img_metas, **kwargs):
+ """Placeholder for augmentation test."""
+ pass
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in segmentor.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if pretrained is not None:
+ logger = logging.getLogger()
+ logger.info(f'load model from: {pretrained}')
+
+ def forward_test(self, imgs, img_metas, **kwargs):
+ """
+ Args:
+ imgs (List[Tensor]): the outer list indicates test-time
+ augmentations and inner Tensor should have a shape NxCxHxW,
+ which contains all images in the batch.
+ img_metas (List[List[dict]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+ images in a batch.
+ """
+ for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
+ if not isinstance(var, list):
+ raise TypeError(f'{name} must be a list, but got '
+ f'{type(var)}')
+
+ num_augs = len(imgs)
+ if num_augs != len(img_metas):
+ raise ValueError(f'num of augmentations ({len(imgs)}) != '
+ f'num of image meta ({len(img_metas)})')
+ # all images in the same aug batch all of the same ori_shape and pad
+ # shape
+ for img_meta in img_metas:
+ ori_shapes = [_['ori_shape'] for _ in img_meta]
+ assert all(shape == ori_shapes[0] for shape in ori_shapes)
+ img_shapes = [_['img_shape'] for _ in img_meta]
+ assert all(shape == img_shapes[0] for shape in img_shapes)
+ pad_shapes = [_['pad_shape'] for _ in img_meta]
+ assert all(shape == pad_shapes[0] for shape in pad_shapes)
+
+ if num_augs == 1:
+ return self.simple_test(imgs[0], img_metas[0], **kwargs)
+ else:
+ return self.aug_test(imgs, img_metas, **kwargs)
+
+ @auto_fp16(apply_to=('img', ))
+ def forward(self, img, img_metas, return_loss=True, **kwargs):
+ """Calls either :func:`forward_train` or :func:`forward_test` depending
+ on whether ``return_loss`` is ``True``.
+
+ Note this setting will change the expected inputs. When
+ ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
+ and List[dict]), and when ``resturn_loss=False``, img and img_meta
+ should be double nested (i.e. List[Tensor], List[List[dict]]), with
+ the outer list indicating test time augmentations.
+ """
+ if return_loss:
+ return self.forward_train(img, img_metas, **kwargs)
+ else:
+ return self.forward_test(img, img_metas, **kwargs)
+
+ def train_step(self, data_batch, optimizer, **kwargs):
+ """The iteration step during training.
+
+ This method defines an iteration step during training, except for the
+ back propagation and optimizer updating, which are done in an optimizer
+ hook. Note that in some complicated cases or models, the whole process
+ including back propagation and optimizer updating is also defined in
+ this method, such as GAN.
+
+ Args:
+ data (dict): The output of dataloader.
+ optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
+ runner is passed to ``train_step()``. This argument is unused
+ and reserved.
+
+ Returns:
+ dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
+ ``num_samples``.
+ ``loss`` is a tensor for back propagation, which can be a
+ weighted sum of multiple losses.
+ ``log_vars`` contains all the variables to be sent to the
+ logger.
+ ``num_samples`` indicates the batch size (when the model is
+ DDP, it means the batch size on each GPU), which is used for
+ averaging the logs.
+ """
+ losses = self(**data_batch)
+ loss, log_vars = self._parse_losses(losses)
+
+ outputs = dict(
+ loss=loss,
+ log_vars=log_vars,
+ num_samples=len(data_batch['img_metas']))
+
+ return outputs
+
+ def val_step(self, data_batch, **kwargs):
+ """The iteration step during validation.
+
+ This method shares the same signature as :func:`train_step`, but used
+ during val epochs. Note that the evaluation after training epochs is
+ not implemented with this method, but an evaluation hook.
+ """
+ output = self(**data_batch, **kwargs)
+ return output
+
+ @staticmethod
+ def _parse_losses(losses):
+ """Parse the raw outputs (losses) of the network.
+
+ Args:
+ losses (dict): Raw output of the network, which usually contain
+ losses and other necessary information.
+
+ Returns:
+ tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
+ which may be a weighted sum of all losses, log_vars contains
+ all the variables to be sent to the logger.
+ """
+ log_vars = OrderedDict()
+ for loss_name, loss_value in losses.items():
+ if isinstance(loss_value, torch.Tensor):
+ log_vars[loss_name] = loss_value.mean()
+ elif isinstance(loss_value, list):
+ log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
+ else:
+ raise TypeError(
+ f'{loss_name} is not a tensor or list of tensors')
+
+ loss = sum(_value for _key, _value in log_vars.items()
+ if 'loss' in _key)
+
+ log_vars['loss'] = loss
+ for loss_name, loss_value in log_vars.items():
+ # reduce loss when distributed training
+ if dist.is_available() and dist.is_initialized():
+ loss_value = loss_value.data.clone()
+ dist.all_reduce(loss_value.div_(dist.get_world_size()))
+ log_vars[loss_name] = loss_value.item()
+
+ return loss, log_vars
+
+ def show_result(self,
+ img,
+ result,
+ palette=None,
+ win_name='',
+ show=False,
+ wait_time=0,
+ out_file=None,
+ opacity=0.5):
+ """Draw `result` over `img`.
+
+ Args:
+ img (str or Tensor): The image to be displayed.
+ result (Tensor): The semantic segmentation results to draw over
+ `img`.
+ palette (list[list[int]]] | np.ndarray | None): The palette of
+ segmentation map. If None is given, random palette will be
+ generated. Default: None
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ Default: 0.
+ show (bool): Whether to show the image.
+ Default: False.
+ out_file (str or None): The filename to write the image.
+ Default: None.
+ opacity(float): Opacity of painted segmentation map.
+ Default 0.5.
+ Must be in (0, 1] range.
+ Returns:
+ img (Tensor): Only if not `show` or `out_file`
+ """
+ img = mmcv.imread(img)
+ img = img.copy()
+ seg = result[0]
+ if palette is None:
+ if self.PALETTE is None:
+ palette = np.random.randint(
+ 0, 255, size=(len(self.CLASSES), 3))
+ else:
+ palette = self.PALETTE
+ palette = np.array(palette)
+ assert palette.shape[0] == len(self.CLASSES)
+ assert palette.shape[1] == 3
+ assert len(palette.shape) == 2
+ assert 0 < opacity <= 1.0
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
+ for label, color in enumerate(palette):
+ color_seg[seg == label, :] = color
+ # convert to BGR
+ color_seg = color_seg[..., ::-1]
+
+ img = img * (1 - opacity) + color_seg * opacity
+ img = img.astype(np.uint8)
+ # if out_file specified, do not show image in window
+ if out_file is not None:
+ show = False
+
+ if show:
+ mmcv.imshow(img, win_name, wait_time)
+ if out_file is not None:
+ mmcv.imwrite(img, out_file)
+
+ if not (show or out_file):
+ warnings.warn('show==False and out_file is not specified, only '
+ 'result image will be returned')
+ return img
diff --git a/src/custom_mmpkg/custom_mmseg/models/segmentors/cascade_encoder_decoder.py b/src/custom_mmpkg/custom_mmseg/models/segmentors/cascade_encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdece0ac89713a0cead10d98959bad2e2d05c4e7
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/segmentors/cascade_encoder_decoder.py
@@ -0,0 +1,98 @@
+from torch import nn
+
+from custom_mmpkg.custom_mmseg.core import add_prefix
+from custom_mmpkg.custom_mmseg.ops import resize
+from .. import builder
+from ..builder import SEGMENTORS
+from .encoder_decoder import EncoderDecoder
+
+
+@SEGMENTORS.register_module()
+class CascadeEncoderDecoder(EncoderDecoder):
+ """Cascade Encoder Decoder segmentors.
+
+ CascadeEncoderDecoder almost the same as EncoderDecoder, while decoders of
+ CascadeEncoderDecoder are cascaded. The output of previous decoder_head
+ will be the input of next decoder_head.
+ """
+
+ def __init__(self,
+ num_stages,
+ backbone,
+ decode_head,
+ neck=None,
+ auxiliary_head=None,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ self.num_stages = num_stages
+ super(CascadeEncoderDecoder, self).__init__(
+ backbone=backbone,
+ decode_head=decode_head,
+ neck=neck,
+ auxiliary_head=auxiliary_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ pretrained=pretrained)
+
+ def _init_decode_head(self, decode_head):
+ """Initialize ``decode_head``"""
+ assert isinstance(decode_head, list)
+ assert len(decode_head) == self.num_stages
+ self.decode_head = nn.ModuleList()
+ for i in range(self.num_stages):
+ self.decode_head.append(builder.build_head(decode_head[i]))
+ self.align_corners = self.decode_head[-1].align_corners
+ self.num_classes = self.decode_head[-1].num_classes
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone and heads.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ self.backbone.init_weights(pretrained=pretrained)
+ for i in range(self.num_stages):
+ self.decode_head[i].init_weights()
+ if self.with_auxiliary_head:
+ if isinstance(self.auxiliary_head, nn.ModuleList):
+ for aux_head in self.auxiliary_head:
+ aux_head.init_weights()
+ else:
+ self.auxiliary_head.init_weights()
+
+ def encode_decode(self, img, img_metas):
+ """Encode images with backbone and decode into a semantic segmentation
+ map of the same size as input."""
+ x = self.extract_feat(img)
+ out = self.decode_head[0].forward_test(x, img_metas, self.test_cfg)
+ for i in range(1, self.num_stages):
+ out = self.decode_head[i].forward_test(x, out, img_metas,
+ self.test_cfg)
+ out = resize(
+ input=out,
+ size=img.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ return out
+
+ def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg):
+ """Run forward function and calculate loss for decode head in
+ training."""
+ losses = dict()
+
+ loss_decode = self.decode_head[0].forward_train(
+ x, img_metas, gt_semantic_seg, self.train_cfg)
+
+ losses.update(add_prefix(loss_decode, 'decode_0'))
+
+ for i in range(1, self.num_stages):
+ # forward test again, maybe unnecessary for most methods.
+ prev_outputs = self.decode_head[i - 1].forward_test(
+ x, img_metas, self.test_cfg)
+ loss_decode = self.decode_head[i].forward_train(
+ x, prev_outputs, img_metas, gt_semantic_seg, self.train_cfg)
+ losses.update(add_prefix(loss_decode, f'decode_{i}'))
+
+ return losses
diff --git a/src/custom_mmpkg/custom_mmseg/models/segmentors/encoder_decoder.py b/src/custom_mmpkg/custom_mmseg/models/segmentors/encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8e42be53aba4b492efec32877bfe4c6a7a1e2aa
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/segmentors/encoder_decoder.py
@@ -0,0 +1,298 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from custom_mmpkg.custom_mmseg.core import add_prefix
+from custom_mmpkg.custom_mmseg.ops import resize
+from .. import builder
+from ..builder import SEGMENTORS
+from .base import BaseSegmentor
+
+
+@SEGMENTORS.register_module()
+class EncoderDecoder(BaseSegmentor):
+ """Encoder Decoder segmentors.
+
+ EncoderDecoder typically consists of backbone, decode_head, auxiliary_head.
+ Note that auxiliary_head is only used for deep supervision during training,
+ which could be dumped during inference.
+ """
+
+ def __init__(self,
+ backbone,
+ decode_head,
+ neck=None,
+ auxiliary_head=None,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ super(EncoderDecoder, self).__init__()
+ self.backbone = builder.build_backbone(backbone)
+ if neck is not None:
+ self.neck = builder.build_neck(neck)
+ self._init_decode_head(decode_head)
+ self._init_auxiliary_head(auxiliary_head)
+
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+
+ self.init_weights(pretrained=pretrained)
+
+ assert self.with_decode_head
+
+ def _init_decode_head(self, decode_head):
+ """Initialize ``decode_head``"""
+ self.decode_head = builder.build_head(decode_head)
+ self.align_corners = self.decode_head.align_corners
+ self.num_classes = self.decode_head.num_classes
+
+ def _init_auxiliary_head(self, auxiliary_head):
+ """Initialize ``auxiliary_head``"""
+ if auxiliary_head is not None:
+ if isinstance(auxiliary_head, list):
+ self.auxiliary_head = nn.ModuleList()
+ for head_cfg in auxiliary_head:
+ self.auxiliary_head.append(builder.build_head(head_cfg))
+ else:
+ self.auxiliary_head = builder.build_head(auxiliary_head)
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone and heads.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+
+ super(EncoderDecoder, self).init_weights(pretrained)
+ self.backbone.init_weights(pretrained=pretrained)
+ self.decode_head.init_weights()
+ if self.with_auxiliary_head:
+ if isinstance(self.auxiliary_head, nn.ModuleList):
+ for aux_head in self.auxiliary_head:
+ aux_head.init_weights()
+ else:
+ self.auxiliary_head.init_weights()
+
+ def extract_feat(self, img):
+ """Extract features from images."""
+ x = self.backbone(img)
+ if self.with_neck:
+ x = self.neck(x)
+ return x
+
+ def encode_decode(self, img, img_metas):
+ """Encode images with backbone and decode into a semantic segmentation
+ map of the same size as input."""
+ x = self.extract_feat(img)
+ out = self._decode_head_forward_test(x, img_metas)
+ out = resize(
+ input=out,
+ size=img.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ return out
+
+ def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg):
+ """Run forward function and calculate loss for decode head in
+ training."""
+ losses = dict()
+ loss_decode = self.decode_head.forward_train(x, img_metas,
+ gt_semantic_seg,
+ self.train_cfg)
+
+ losses.update(add_prefix(loss_decode, 'decode'))
+ return losses
+
+ def _decode_head_forward_test(self, x, img_metas):
+ """Run forward function and calculate loss for decode head in
+ inference."""
+ seg_logits = self.decode_head.forward_test(x, img_metas, self.test_cfg)
+ return seg_logits
+
+ def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg):
+ """Run forward function and calculate loss for auxiliary head in
+ training."""
+ losses = dict()
+ if isinstance(self.auxiliary_head, nn.ModuleList):
+ for idx, aux_head in enumerate(self.auxiliary_head):
+ loss_aux = aux_head.forward_train(x, img_metas,
+ gt_semantic_seg,
+ self.train_cfg)
+ losses.update(add_prefix(loss_aux, f'aux_{idx}'))
+ else:
+ loss_aux = self.auxiliary_head.forward_train(
+ x, img_metas, gt_semantic_seg, self.train_cfg)
+ losses.update(add_prefix(loss_aux, 'aux'))
+
+ return losses
+
+ def forward_dummy(self, img):
+ """Dummy forward function."""
+ seg_logit = self.encode_decode(img, None)
+
+ return seg_logit
+
+ def forward_train(self, img, img_metas, gt_semantic_seg):
+ """Forward function for training.
+
+ Args:
+ img (Tensor): Input images.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ gt_semantic_seg (Tensor): Semantic segmentation masks
+ used if the architecture supports semantic segmentation task.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+
+ x = self.extract_feat(img)
+
+ losses = dict()
+
+ loss_decode = self._decode_head_forward_train(x, img_metas,
+ gt_semantic_seg)
+ losses.update(loss_decode)
+
+ if self.with_auxiliary_head:
+ loss_aux = self._auxiliary_head_forward_train(
+ x, img_metas, gt_semantic_seg)
+ losses.update(loss_aux)
+
+ return losses
+
+ # TODO refactor
+ def slide_inference(self, img, img_meta, rescale):
+ """Inference by sliding-window with overlap.
+
+ If h_crop > h_img or w_crop > w_img, the small patch will be used to
+ decode without padding.
+ """
+
+ h_stride, w_stride = self.test_cfg.stride
+ h_crop, w_crop = self.test_cfg.crop_size
+ batch_size, _, h_img, w_img = img.size()
+ num_classes = self.num_classes
+ h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
+ w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
+ preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
+ count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
+ for h_idx in range(h_grids):
+ for w_idx in range(w_grids):
+ y1 = h_idx * h_stride
+ x1 = w_idx * w_stride
+ y2 = min(y1 + h_crop, h_img)
+ x2 = min(x1 + w_crop, w_img)
+ y1 = max(y2 - h_crop, 0)
+ x1 = max(x2 - w_crop, 0)
+ crop_img = img[:, :, y1:y2, x1:x2]
+ crop_seg_logit = self.encode_decode(crop_img, img_meta)
+ preds += F.pad(crop_seg_logit,
+ (int(x1), int(preds.shape[3] - x2), int(y1),
+ int(preds.shape[2] - y2)))
+
+ count_mat[:, :, y1:y2, x1:x2] += 1
+ assert (count_mat == 0).sum() == 0
+ if torch.onnx.is_in_onnx_export():
+ # cast count_mat to constant while exporting to ONNX
+ count_mat = torch.from_numpy(
+ count_mat.cpu().detach().numpy()).to(device=img.device)
+ preds = preds / count_mat
+ if rescale:
+ preds = resize(
+ preds,
+ size=img_meta[0]['ori_shape'][:2],
+ mode='bilinear',
+ align_corners=self.align_corners,
+ warning=False)
+ return preds
+
+ def whole_inference(self, img, img_meta, rescale):
+ """Inference with full image."""
+
+ seg_logit = self.encode_decode(img, img_meta)
+ if rescale:
+ # support dynamic shape for onnx
+ if torch.onnx.is_in_onnx_export():
+ size = img.shape[2:]
+ else:
+ size = img_meta[0]['ori_shape'][:2]
+ seg_logit = resize(
+ seg_logit,
+ size=size,
+ mode='bilinear',
+ align_corners=self.align_corners,
+ warning=False)
+
+ return seg_logit
+
+ def inference(self, img, img_meta, rescale):
+ """Inference with slide/whole style.
+
+ Args:
+ img (Tensor): The input image of shape (N, 3, H, W).
+ img_meta (dict): Image info dict where each dict has: 'img_shape',
+ 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ rescale (bool): Whether rescale back to original shape.
+
+ Returns:
+ Tensor: The output segmentation map.
+ """
+
+ assert self.test_cfg.mode in ['slide', 'whole']
+ ori_shape = img_meta[0]['ori_shape']
+ assert all(_['ori_shape'] == ori_shape for _ in img_meta)
+ if self.test_cfg.mode == 'slide':
+ seg_logit = self.slide_inference(img, img_meta, rescale)
+ else:
+ seg_logit = self.whole_inference(img, img_meta, rescale)
+ output = F.softmax(seg_logit, dim=1)
+ flip = img_meta[0]['flip']
+ if flip:
+ flip_direction = img_meta[0]['flip_direction']
+ assert flip_direction in ['horizontal', 'vertical']
+ if flip_direction == 'horizontal':
+ output = output.flip(dims=(3, ))
+ elif flip_direction == 'vertical':
+ output = output.flip(dims=(2, ))
+
+ return output
+
+ def simple_test(self, img, img_meta, rescale=True):
+ """Simple test with single image."""
+ seg_logit = self.inference(img, img_meta, rescale)
+ seg_pred = seg_logit.argmax(dim=1)
+ if torch.onnx.is_in_onnx_export():
+ # our inference backend only support 4D output
+ seg_pred = seg_pred.unsqueeze(0)
+ return seg_pred
+ seg_pred = seg_pred.cpu().numpy()
+ # unravel batch dim
+ seg_pred = list(seg_pred)
+ return seg_pred
+
+ def aug_test(self, imgs, img_metas, rescale=True):
+ """Test with augmentations.
+
+ Only rescale=True is supported.
+ """
+ # aug_test rescale all imgs back to ori_shape for now
+ assert rescale
+ # to save memory, we get augmented seg logit inplace
+ seg_logit = self.inference(imgs[0], img_metas[0], rescale)
+ for i in range(1, len(imgs)):
+ cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale)
+ seg_logit += cur_seg_logit
+ seg_logit /= len(imgs)
+ seg_pred = seg_logit.argmax(dim=1)
+ seg_pred = seg_pred.cpu().numpy()
+ # unravel batch dim
+ seg_pred = list(seg_pred)
+ return seg_pred
diff --git a/src/custom_mmpkg/custom_mmseg/models/utils/__init__.py b/src/custom_mmpkg/custom_mmseg/models/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d3bdd349b9f2ae499a2fcb2ac1d2e3c77befebe
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/utils/__init__.py
@@ -0,0 +1,13 @@
+from .drop import DropPath
+from .inverted_residual import InvertedResidual, InvertedResidualV3
+from .make_divisible import make_divisible
+from .res_layer import ResLayer
+from .se_layer import SELayer
+from .self_attention_block import SelfAttentionBlock
+from .up_conv_block import UpConvBlock
+from .weight_init import trunc_normal_
+
+__all__ = [
+ 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
+ 'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'DropPath', 'trunc_normal_'
+]
diff --git a/src/custom_mmpkg/custom_mmseg/models/utils/drop.py b/src/custom_mmpkg/custom_mmseg/models/utils/drop.py
new file mode 100644
index 0000000000000000000000000000000000000000..4520b0ff407d2a95a864086bdbca0065f222aa63
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/utils/drop.py
@@ -0,0 +1,31 @@
+"""Modified from https://github.com/rwightman/pytorch-image-
+models/blob/master/timm/models/layers/drop.py."""
+
+import torch
+from torch import nn
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
+ residual blocks).
+
+ Args:
+ drop_prob (float): Drop rate for paths of model. Dropout rate has
+ to be between 0 and 1. Default: 0.
+ """
+
+ def __init__(self, drop_prob=0.):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.keep_prob = 1 - drop_prob
+
+ def forward(self, x):
+ if self.drop_prob == 0. or not self.training:
+ return x
+ shape = (x.shape[0], ) + (1, ) * (
+ x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = self.keep_prob + torch.rand(
+ shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_() # binarize
+ output = x.div(self.keep_prob) * random_tensor
+ return output
diff --git a/src/custom_mmpkg/custom_mmseg/models/utils/inverted_residual.py b/src/custom_mmpkg/custom_mmseg/models/utils/inverted_residual.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fb93391f83c15c91cca833a296b922380607e66
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/utils/inverted_residual.py
@@ -0,0 +1,208 @@
+from custom_mmpkg.custom_mmcv.cnn import ConvModule
+from torch import nn
+from torch.utils import checkpoint as cp
+
+from .se_layer import SELayer
+
+
+class InvertedResidual(nn.Module):
+ """InvertedResidual block for MobileNetV2.
+
+ Args:
+ in_channels (int): The input channels of the InvertedResidual block.
+ out_channels (int): The output channels of the InvertedResidual block.
+ stride (int): Stride of the middle (first) 3x3 convolution.
+ expand_ratio (int): Adjusts number of channels of the hidden layer
+ in InvertedResidual by this amount.
+ dilation (int): Dilation rate of depthwise conv. Default: 1
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU6').
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+
+ Returns:
+ Tensor: The output tensor.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride,
+ expand_ratio,
+ dilation=1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU6'),
+ with_cp=False):
+ super(InvertedResidual, self).__init__()
+ self.stride = stride
+ assert stride in [1, 2], f'stride must in [1, 2]. ' \
+ f'But received {stride}.'
+ self.with_cp = with_cp
+ self.use_res_connect = self.stride == 1 and in_channels == out_channels
+ hidden_dim = int(round(in_channels * expand_ratio))
+
+ layers = []
+ if expand_ratio != 1:
+ layers.append(
+ ConvModule(
+ in_channels=in_channels,
+ out_channels=hidden_dim,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ layers.extend([
+ ConvModule(
+ in_channels=hidden_dim,
+ out_channels=hidden_dim,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ groups=hidden_dim,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg),
+ ConvModule(
+ in_channels=hidden_dim,
+ out_channels=out_channels,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+ ])
+ self.conv = nn.Sequential(*layers)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ if self.use_res_connect:
+ return x + self.conv(x)
+ else:
+ return self.conv(x)
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ return out
+
+
+class InvertedResidualV3(nn.Module):
+ """Inverted Residual Block for MobileNetV3.
+
+ Args:
+ in_channels (int): The input channels of this Module.
+ out_channels (int): The output channels of this Module.
+ mid_channels (int): The input channels of the depthwise convolution.
+ kernel_size (int): The kernel size of the depthwise convolution.
+ Default: 3.
+ stride (int): The stride of the depthwise convolution. Default: 1.
+ se_cfg (dict): Config dict for se layer. Default: None, which means no
+ se layer.
+ with_expand_conv (bool): Use expand conv or not. If set False,
+ mid_channels must be the same with in_channels. Default: True.
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
+ which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU').
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+
+ Returns:
+ Tensor: The output tensor.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ mid_channels,
+ kernel_size=3,
+ stride=1,
+ se_cfg=None,
+ with_expand_conv=True,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ with_cp=False):
+ super(InvertedResidualV3, self).__init__()
+ self.with_res_shortcut = (stride == 1 and in_channels == out_channels)
+ assert stride in [1, 2]
+ self.with_cp = with_cp
+ self.with_se = se_cfg is not None
+ self.with_expand_conv = with_expand_conv
+
+ if self.with_se:
+ assert isinstance(se_cfg, dict)
+ if not self.with_expand_conv:
+ assert mid_channels == in_channels
+
+ if self.with_expand_conv:
+ self.expand_conv = ConvModule(
+ in_channels=in_channels,
+ out_channels=mid_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.depthwise_conv = ConvModule(
+ in_channels=mid_channels,
+ out_channels=mid_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=kernel_size // 2,
+ groups=mid_channels,
+ conv_cfg=dict(
+ type='Conv2dAdaptivePadding') if stride == 2 else conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ if self.with_se:
+ self.se = SELayer(**se_cfg)
+
+ self.linear_conv = ConvModule(
+ in_channels=mid_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ out = x
+
+ if self.with_expand_conv:
+ out = self.expand_conv(out)
+
+ out = self.depthwise_conv(out)
+
+ if self.with_se:
+ out = self.se(out)
+
+ out = self.linear_conv(out)
+
+ if self.with_res_shortcut:
+ return x + out
+ else:
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ return out
diff --git a/src/custom_mmpkg/custom_mmseg/models/utils/make_divisible.py b/src/custom_mmpkg/custom_mmseg/models/utils/make_divisible.py
new file mode 100644
index 0000000000000000000000000000000000000000..75ad756052529f52fe83bb95dd1f0ecfc9a13078
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/utils/make_divisible.py
@@ -0,0 +1,27 @@
+def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
+ """Make divisible function.
+
+ This function rounds the channel number to the nearest value that can be
+ divisible by the divisor. It is taken from the original tf repo. It ensures
+ that all layers have a channel number that is divisible by divisor. It can
+ be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa
+
+ Args:
+ value (int): The original channel number.
+ divisor (int): The divisor to fully divide the channel number.
+ min_value (int): The minimum value of the output channel.
+ Default: None, means that the minimum value equal to the divisor.
+ min_ratio (float): The minimum ratio of the rounded channel number to
+ the original channel number. Default: 0.9.
+
+ Returns:
+ int: The modified output channel number.
+ """
+
+ if min_value is None:
+ min_value = divisor
+ new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than (1-min_ratio).
+ if new_value < min_ratio * value:
+ new_value += divisor
+ return new_value
diff --git a/src/custom_mmpkg/custom_mmseg/models/utils/res_layer.py b/src/custom_mmpkg/custom_mmseg/models/utils/res_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1379e62e7c6e591fae25e57d548c5f735e3ad33a
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/utils/res_layer.py
@@ -0,0 +1,94 @@
+from custom_mmpkg.custom_mmcv.cnn import build_conv_layer, build_norm_layer
+from torch import nn as nn
+
+
+class ResLayer(nn.Sequential):
+ """ResLayer to build ResNet style backbone.
+
+ Args:
+ block (nn.Module): block used to build ResLayer.
+ inplanes (int): inplanes of block.
+ planes (int): planes of block.
+ num_blocks (int): number of blocks.
+ stride (int): stride of the first block. Default: 1
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck. Default: False
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ multi_grid (int | None): Multi grid dilation rates of last
+ stage. Default: None
+ contract_dilation (bool): Whether contract first dilation of each layer
+ Default: False
+ """
+
+ def __init__(self,
+ block,
+ inplanes,
+ planes,
+ num_blocks,
+ stride=1,
+ dilation=1,
+ avg_down=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ multi_grid=None,
+ contract_dilation=False,
+ **kwargs):
+ self.block = block
+
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = []
+ conv_stride = stride
+ if avg_down:
+ conv_stride = 1
+ downsample.append(
+ nn.AvgPool2d(
+ kernel_size=stride,
+ stride=stride,
+ ceil_mode=True,
+ count_include_pad=False))
+ downsample.extend([
+ build_conv_layer(
+ conv_cfg,
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=conv_stride,
+ bias=False),
+ build_norm_layer(norm_cfg, planes * block.expansion)[1]
+ ])
+ downsample = nn.Sequential(*downsample)
+
+ layers = []
+ if multi_grid is None:
+ if dilation > 1 and contract_dilation:
+ first_dilation = dilation // 2
+ else:
+ first_dilation = dilation
+ else:
+ first_dilation = multi_grid[0]
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=stride,
+ dilation=first_dilation,
+ downsample=downsample,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs))
+ inplanes = planes * block.expansion
+ for i in range(1, num_blocks):
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=1,
+ dilation=dilation if multi_grid is None else multi_grid[i],
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs))
+ super(ResLayer, self).__init__(*layers)
diff --git a/src/custom_mmpkg/custom_mmseg/models/utils/se_layer.py b/src/custom_mmpkg/custom_mmseg/models/utils/se_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ec17f8d6713a441de2400186ded50ff651821ca
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/utils/se_layer.py
@@ -0,0 +1,57 @@
+import custom_mmpkg.custom_mmcv as mmcv
+import torch.nn as nn
+from custom_mmpkg.custom_mmcv.cnn import ConvModule
+
+from .make_divisible import make_divisible
+
+
+class SELayer(nn.Module):
+ """Squeeze-and-Excitation Module.
+
+ Args:
+ channels (int): The input (and output) channels of the SE layer.
+ ratio (int): Squeeze ratio in SELayer, the intermediate channel will be
+ ``int(channels/ratio)``. Default: 16.
+ conv_cfg (None or dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ act_cfg (dict or Sequence[dict]): Config dict for activation layer.
+ If act_cfg is a dict, two activation layers will be configured
+ by this dict. If act_cfg is a sequence of dicts, the first
+ activation layer will be configured by the first dict and the
+ second activation layer will be configured by the second dict.
+ Default: (dict(type='ReLU'), dict(type='HSigmoid', bias=3.0,
+ divisor=6.0)).
+ """
+
+ def __init__(self,
+ channels,
+ ratio=16,
+ conv_cfg=None,
+ act_cfg=(dict(type='ReLU'),
+ dict(type='HSigmoid', bias=3.0, divisor=6.0))):
+ super(SELayer, self).__init__()
+ if isinstance(act_cfg, dict):
+ act_cfg = (act_cfg, act_cfg)
+ assert len(act_cfg) == 2
+ assert mmcv.is_tuple_of(act_cfg, dict)
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+ self.conv1 = ConvModule(
+ in_channels=channels,
+ out_channels=make_divisible(channels // ratio, 8),
+ kernel_size=1,
+ stride=1,
+ conv_cfg=conv_cfg,
+ act_cfg=act_cfg[0])
+ self.conv2 = ConvModule(
+ in_channels=make_divisible(channels // ratio, 8),
+ out_channels=channels,
+ kernel_size=1,
+ stride=1,
+ conv_cfg=conv_cfg,
+ act_cfg=act_cfg[1])
+
+ def forward(self, x):
+ out = self.global_avgpool(x)
+ out = self.conv1(out)
+ out = self.conv2(out)
+ return x * out
diff --git a/src/custom_mmpkg/custom_mmseg/models/utils/self_attention_block.py b/src/custom_mmpkg/custom_mmseg/models/utils/self_attention_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad20ca6cf14b4dce040f350dbdd0fee6ce5ed9cf
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/utils/self_attention_block.py
@@ -0,0 +1,159 @@
+import torch
+from custom_mmpkg.custom_mmcv.cnn import ConvModule, constant_init
+from torch import nn as nn
+from torch.nn import functional as F
+
+
+class SelfAttentionBlock(nn.Module):
+ """General self-attention block/non-local block.
+
+ Please refer to https://arxiv.org/abs/1706.03762 for details about key,
+ query and value.
+
+ Args:
+ key_in_channels (int): Input channels of key feature.
+ query_in_channels (int): Input channels of query feature.
+ channels (int): Output channels of key/query transform.
+ out_channels (int): Output channels.
+ share_key_query (bool): Whether share projection weight between key
+ and query projection.
+ query_downsample (nn.Module): Query downsample module.
+ key_downsample (nn.Module): Key downsample module.
+ key_query_num_convs (int): Number of convs for key/query projection.
+ value_num_convs (int): Number of convs for value projection.
+ matmul_norm (bool): Whether normalize attention map with sqrt of
+ channels
+ with_out (bool): Whether use out projection.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict|None): Config of activation layers.
+ """
+
+ def __init__(self, key_in_channels, query_in_channels, channels,
+ out_channels, share_key_query, query_downsample,
+ key_downsample, key_query_num_convs, value_out_num_convs,
+ key_query_norm, value_out_norm, matmul_norm, with_out,
+ conv_cfg, norm_cfg, act_cfg):
+ super(SelfAttentionBlock, self).__init__()
+ if share_key_query:
+ assert key_in_channels == query_in_channels
+ self.key_in_channels = key_in_channels
+ self.query_in_channels = query_in_channels
+ self.out_channels = out_channels
+ self.channels = channels
+ self.share_key_query = share_key_query
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.key_project = self.build_project(
+ key_in_channels,
+ channels,
+ num_convs=key_query_num_convs,
+ use_conv_module=key_query_norm,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ if share_key_query:
+ self.query_project = self.key_project
+ else:
+ self.query_project = self.build_project(
+ query_in_channels,
+ channels,
+ num_convs=key_query_num_convs,
+ use_conv_module=key_query_norm,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.value_project = self.build_project(
+ key_in_channels,
+ channels if with_out else out_channels,
+ num_convs=value_out_num_convs,
+ use_conv_module=value_out_norm,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ if with_out:
+ self.out_project = self.build_project(
+ channels,
+ out_channels,
+ num_convs=value_out_num_convs,
+ use_conv_module=value_out_norm,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ else:
+ self.out_project = None
+
+ self.query_downsample = query_downsample
+ self.key_downsample = key_downsample
+ self.matmul_norm = matmul_norm
+
+ self.init_weights()
+
+ def init_weights(self):
+ """Initialize weight of later layer."""
+ if self.out_project is not None:
+ if not isinstance(self.out_project, ConvModule):
+ constant_init(self.out_project, 0)
+
+ def build_project(self, in_channels, channels, num_convs, use_conv_module,
+ conv_cfg, norm_cfg, act_cfg):
+ """Build projection layer for key/query/value/out."""
+ if use_conv_module:
+ convs = [
+ ConvModule(
+ in_channels,
+ channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ ]
+ for _ in range(num_convs - 1):
+ convs.append(
+ ConvModule(
+ channels,
+ channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ else:
+ convs = [nn.Conv2d(in_channels, channels, 1)]
+ for _ in range(num_convs - 1):
+ convs.append(nn.Conv2d(channels, channels, 1))
+ if len(convs) > 1:
+ convs = nn.Sequential(*convs)
+ else:
+ convs = convs[0]
+ return convs
+
+ def forward(self, query_feats, key_feats):
+ """Forward function."""
+ batch_size = query_feats.size(0)
+ query = self.query_project(query_feats)
+ if self.query_downsample is not None:
+ query = self.query_downsample(query)
+ query = query.reshape(*query.shape[:2], -1)
+ query = query.permute(0, 2, 1).contiguous()
+
+ key = self.key_project(key_feats)
+ value = self.value_project(key_feats)
+ if self.key_downsample is not None:
+ key = self.key_downsample(key)
+ value = self.key_downsample(value)
+ key = key.reshape(*key.shape[:2], -1)
+ value = value.reshape(*value.shape[:2], -1)
+ value = value.permute(0, 2, 1).contiguous()
+
+ sim_map = torch.matmul(query, key)
+ if self.matmul_norm:
+ sim_map = (self.channels**-.5) * sim_map
+ sim_map = F.softmax(sim_map, dim=-1)
+
+ context = torch.matmul(sim_map, value)
+ context = context.permute(0, 2, 1).contiguous()
+ context = context.reshape(batch_size, -1, *query_feats.shape[2:])
+ if self.out_project is not None:
+ context = self.out_project(context)
+ return context
diff --git a/src/custom_mmpkg/custom_mmseg/models/utils/up_conv_block.py b/src/custom_mmpkg/custom_mmseg/models/utils/up_conv_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4320e261db00a8bd0ba2578bcf2fdde952d0270
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/utils/up_conv_block.py
@@ -0,0 +1,101 @@
+import torch
+import torch.nn as nn
+from custom_mmpkg.custom_mmcv.cnn import ConvModule, build_upsample_layer
+
+
+class UpConvBlock(nn.Module):
+ """Upsample convolution block in decoder for UNet.
+
+ This upsample convolution block consists of one upsample module
+ followed by one convolution block. The upsample module expands the
+ high-level low-resolution feature map and the convolution block fuses
+ the upsampled high-level low-resolution feature map and the low-level
+ high-resolution feature map from encoder.
+
+ Args:
+ conv_block (nn.Sequential): Sequential of convolutional layers.
+ in_channels (int): Number of input channels of the high-level
+ skip_channels (int): Number of input channels of the low-level
+ high-resolution feature map from encoder.
+ out_channels (int): Number of output channels.
+ num_convs (int): Number of convolutional layers in the conv_block.
+ Default: 2.
+ stride (int): Stride of convolutional layer in conv_block. Default: 1.
+ dilation (int): Dilation rate of convolutional layer in conv_block.
+ Default: 1.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ conv_cfg (dict | None): Config dict for convolution layer.
+ Default: None.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ upsample_cfg (dict): The upsample config of the upsample module in
+ decoder. Default: dict(type='InterpConv'). If the size of
+ high-level feature map is the same as that of skip feature map
+ (low-level feature map from encoder), it does not need upsample the
+ high-level feature map and the upsample_cfg is None.
+ dcn (bool): Use deformable convolution in convolutional layer or not.
+ Default: None.
+ plugins (dict): plugins for convolutional layers. Default: None.
+ """
+
+ def __init__(self,
+ conv_block,
+ in_channels,
+ skip_channels,
+ out_channels,
+ num_convs=2,
+ stride=1,
+ dilation=1,
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ upsample_cfg=dict(type='InterpConv'),
+ dcn=None,
+ plugins=None):
+ super(UpConvBlock, self).__init__()
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+
+ self.conv_block = conv_block(
+ in_channels=2 * skip_channels,
+ out_channels=out_channels,
+ num_convs=num_convs,
+ stride=stride,
+ dilation=dilation,
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ dcn=None,
+ plugins=None)
+ if upsample_cfg is not None:
+ self.upsample = build_upsample_layer(
+ cfg=upsample_cfg,
+ in_channels=in_channels,
+ out_channels=skip_channels,
+ with_cp=with_cp,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ else:
+ self.upsample = ConvModule(
+ in_channels,
+ skip_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ def forward(self, skip, x):
+ """Forward function."""
+
+ x = self.upsample(x)
+ out = torch.cat([skip, x], dim=1)
+ out = self.conv_block(out)
+
+ return out
diff --git a/src/custom_mmpkg/custom_mmseg/models/utils/weight_init.py b/src/custom_mmpkg/custom_mmseg/models/utils/weight_init.py
new file mode 100644
index 0000000000000000000000000000000000000000..38141ba3d61f64ddfc0a31574b4648cbad96d7dd
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/models/utils/weight_init.py
@@ -0,0 +1,62 @@
+"""Modified from https://github.com/rwightman/pytorch-image-
+models/blob/master/timm/models/layers/drop.py."""
+
+import math
+import warnings
+
+import torch
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ """Reference: https://people.sc.fsu.edu/~jburkardt/presentations
+ /truncated_normal.pdf"""
+
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+ 'The distribution of values may be incorrect.',
+ stacklevel=2)
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ lower_bound = norm_cdf((a - mean) / std)
+ upper_bound = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * lower_bound - 1, 2 * upper_bound - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+ Args:
+ tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`
+ mean (float): the mean of the normal distribution
+ std (float): the standard deviation of the normal distribution
+ a (float): the minimum cutoff value
+ b (float): the maximum cutoff value
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
diff --git a/src/custom_mmpkg/custom_mmseg/ops/__init__.py b/src/custom_mmpkg/custom_mmseg/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bec51c75b9363a9a19e9fb5c35f4e7dbd6f7751c
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/ops/__init__.py
@@ -0,0 +1,4 @@
+from .encoding import Encoding
+from .wrappers import Upsample, resize
+
+__all__ = ['Upsample', 'resize', 'Encoding']
diff --git a/src/custom_mmpkg/custom_mmseg/ops/encoding.py b/src/custom_mmpkg/custom_mmseg/ops/encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eb3629a6426550b8e4c537ee1ff4341893e489e
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/ops/encoding.py
@@ -0,0 +1,74 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class Encoding(nn.Module):
+ """Encoding Layer: a learnable residual encoder.
+
+ Input is of shape (batch_size, channels, height, width).
+ Output is of shape (batch_size, num_codes, channels).
+
+ Args:
+ channels: dimension of the features or feature channels
+ num_codes: number of code words
+ """
+
+ def __init__(self, channels, num_codes):
+ super(Encoding, self).__init__()
+ # init codewords and smoothing factor
+ self.channels, self.num_codes = channels, num_codes
+ std = 1. / ((num_codes * channels)**0.5)
+ # [num_codes, channels]
+ self.codewords = nn.Parameter(
+ torch.empty(num_codes, channels,
+ dtype=torch.float).uniform_(-std, std),
+ requires_grad=True)
+ # [num_codes]
+ self.scale = nn.Parameter(
+ torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0),
+ requires_grad=True)
+
+ @staticmethod
+ def scaled_l2(x, codewords, scale):
+ num_codes, channels = codewords.size()
+ batch_size = x.size(0)
+ reshaped_scale = scale.view((1, 1, num_codes))
+ expanded_x = x.unsqueeze(2).expand(
+ (batch_size, x.size(1), num_codes, channels))
+ reshaped_codewords = codewords.view((1, 1, num_codes, channels))
+
+ scaled_l2_norm = reshaped_scale * (
+ expanded_x - reshaped_codewords).pow(2).sum(dim=3)
+ return scaled_l2_norm
+
+ @staticmethod
+ def aggregate(assignment_weights, x, codewords):
+ num_codes, channels = codewords.size()
+ reshaped_codewords = codewords.view((1, 1, num_codes, channels))
+ batch_size = x.size(0)
+
+ expanded_x = x.unsqueeze(2).expand(
+ (batch_size, x.size(1), num_codes, channels))
+ encoded_feat = (assignment_weights.unsqueeze(3) *
+ (expanded_x - reshaped_codewords)).sum(dim=1)
+ return encoded_feat
+
+ def forward(self, x):
+ assert x.dim() == 4 and x.size(1) == self.channels
+ # [batch_size, channels, height, width]
+ batch_size = x.size(0)
+ # [batch_size, height x width, channels]
+ x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous()
+ # assignment_weights: [batch_size, channels, num_codes]
+ assignment_weights = F.softmax(
+ self.scaled_l2(x, self.codewords, self.scale), dim=2)
+ # aggregate
+ encoded_feat = self.aggregate(assignment_weights, x, self.codewords)
+ return encoded_feat
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \
+ f'x{self.channels})'
+ return repr_str
diff --git a/src/custom_mmpkg/custom_mmseg/ops/wrappers.py b/src/custom_mmpkg/custom_mmseg/ops/wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ed9a0cb8d7c0e0ec2748dd89c652756653cac78
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/ops/wrappers.py
@@ -0,0 +1,50 @@
+import warnings
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def resize(input,
+ size=None,
+ scale_factor=None,
+ mode='nearest',
+ align_corners=None,
+ warning=True):
+ if warning:
+ if size is not None and align_corners:
+ input_h, input_w = tuple(int(x) for x in input.shape[2:])
+ output_h, output_w = tuple(int(x) for x in size)
+ if output_h > input_h or output_w > output_h:
+ if ((output_h > 1 and output_w > 1 and input_h > 1
+ and input_w > 1) and (output_h - 1) % (input_h - 1)
+ and (output_w - 1) % (input_w - 1)):
+ warnings.warn(
+ f'When align_corners={align_corners}, '
+ 'the output would more aligned if '
+ f'input size {(input_h, input_w)} is `x+1` and '
+ f'out size {(output_h, output_w)} is `nx+1`')
+ return F.interpolate(input, size, scale_factor, mode, align_corners)
+
+
+class Upsample(nn.Module):
+
+ def __init__(self,
+ size=None,
+ scale_factor=None,
+ mode='nearest',
+ align_corners=None):
+ super(Upsample, self).__init__()
+ self.size = size
+ if isinstance(scale_factor, tuple):
+ self.scale_factor = tuple(float(factor) for factor in scale_factor)
+ else:
+ self.scale_factor = float(scale_factor) if scale_factor else None
+ self.mode = mode
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ if not self.size:
+ size = [int(t * self.scale_factor) for t in x.shape[-2:]]
+ else:
+ size = self.size
+ return resize(x, size, None, self.mode, self.align_corners)
diff --git a/src/custom_mmpkg/custom_mmseg/utils/__init__.py b/src/custom_mmpkg/custom_mmseg/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac489e2dbbc0e6fa87f5088b4edcc20f8cadc1a6
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/utils/__init__.py
@@ -0,0 +1,4 @@
+from .collect_env import collect_env
+from .logger import get_root_logger
+
+__all__ = ['get_root_logger', 'collect_env']
diff --git a/src/custom_mmpkg/custom_mmseg/utils/collect_env.py b/src/custom_mmpkg/custom_mmseg/utils/collect_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce5cd6ffee77a234e7c54d6990d273bbd872b8f5
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/utils/collect_env.py
@@ -0,0 +1,17 @@
+from custom_mmpkg.custom_mmcv.utils import collect_env as collect_base_env
+from custom_mmpkg.custom_mmcv.utils import get_git_hash
+
+import custom_mmpkg.custom_mmseg as mmseg
+
+
+def collect_env():
+ """Collect the information of the running environments."""
+ env_info = collect_base_env()
+ env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}'
+
+ return env_info
+
+
+if __name__ == '__main__':
+ for name, val in collect_env().items():
+ print('{}: {}'.format(name, val))
diff --git a/src/custom_mmpkg/custom_mmseg/utils/logger.py b/src/custom_mmpkg/custom_mmseg/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc2ac05b20a1f7f24a3e7876757d87b00972a69d
--- /dev/null
+++ b/src/custom_mmpkg/custom_mmseg/utils/logger.py
@@ -0,0 +1,27 @@
+import logging
+
+from custom_mmpkg.custom_mmcv.utils import get_logger
+
+
+def get_root_logger(log_file=None, log_level=logging.INFO):
+ """Get the root logger.
+
+ The logger will be initialized if it has not been initialized. By default a
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
+ also be added. The name of the root logger is the top-level package name,
+ e.g., "mmseg".
+
+ Args:
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the root logger.
+ log_level (int): The root logger level. Note that only the process of
+ rank 0 is affected, while other processes will set the level to
+ "Error" and be silent most of the time.
+
+ Returns:
+ logging.Logger: The root logger.
+ """
+
+ logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level)
+
+ return logger
diff --git a/src/wrapper_for_mps/__init__.py b/src/wrapper_for_mps/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d40bcf5d6be0ff616c8704bdf8c699fcda876ba8
--- /dev/null
+++ b/src/wrapper_for_mps/__init__.py
@@ -0,0 +1,7 @@
+import torch
+from comfy.model_management import get_torch_device
+
+device = get_torch_device()
+#https://github.com/microsoft/DirectML/issues/414#issuecomment-1541319479
+def sparse_to_dense(sparse_tensor):
+ return sparse_tensor.to_dense()
\ No newline at end of file
diff --git a/tests/pose.png b/tests/pose.png
new file mode 100644
index 0000000000000000000000000000000000000000..5a04a9b71a45736c7877e8e8cb6b268c1cfc7192
--- /dev/null
+++ b/tests/pose.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:94b9d256939537f0e24d66ec6d6690d06132a1f93775506f58dd6a6bb4354941
+size 591639
diff --git a/tests/test_cn_aux_full.json b/tests/test_cn_aux_full.json
new file mode 100644
index 0000000000000000000000000000000000000000..b6445f36944ab43f99aa7640cfedc3968eb2f6d7
--- /dev/null
+++ b/tests/test_cn_aux_full.json
@@ -0,0 +1,1737 @@
+{
+ "last_node_id": 45,
+ "last_link_id": 44,
+ "nodes": [
+ {
+ "id": 24,
+ "type": "PreviewImage",
+ "pos": [
+ 843,
+ -430
+ ],
+ "size": [
+ 210,
+ 246
+ ],
+ "flags": {},
+ "order": 22,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "images",
+ "type": "IMAGE",
+ "link": 23
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "PreviewImage"
+ }
+ },
+ {
+ "id": 25,
+ "type": "PreviewImage",
+ "pos": [
+ 1127,
+ -346
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 23,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "images",
+ "type": "IMAGE",
+ "link": 24
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "PreviewImage"
+ }
+ },
+ {
+ "id": 26,
+ "type": "PreviewImage",
+ "pos": [
+ 832,
+ -222
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 24,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "images",
+ "type": "IMAGE",
+ "link": 25
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "PreviewImage"
+ }
+ },
+ {
+ "id": 27,
+ "type": "PreviewImage",
+ "pos": [
+ 1144,
+ -123
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 25,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "images",
+ "type": "IMAGE",
+ "link": 26
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "PreviewImage"
+ }
+ },
+ {
+ "id": 28,
+ "type": "PreviewImage",
+ "pos": [
+ 825,
+ 56
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 26,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "images",
+ "type": "IMAGE",
+ "link": 27
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "PreviewImage"
+ }
+ },
+ {
+ "id": 29,
+ "type": "PreviewImage",
+ "pos": [
+ 1240,
+ 246
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 27,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "images",
+ "type": "IMAGE",
+ "link": 28
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "PreviewImage"
+ }
+ },
+ {
+ "id": 30,
+ "type": "PreviewImage",
+ "pos": [
+ 855,
+ 381
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 28,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "images",
+ "type": "IMAGE",
+ "link": 29
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "PreviewImage"
+ }
+ },
+ {
+ "id": 31,
+ "type": "PreviewImage",
+ "pos": [
+ 1248,
+ 471
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 29,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "images",
+ "type": "IMAGE",
+ "link": 30
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "PreviewImage"
+ }
+ },
+ {
+ "id": 32,
+ "type": "PreviewImage",
+ "pos": [
+ 823,
+ 632
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 30,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "images",
+ "type": "IMAGE",
+ "link": 31
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "PreviewImage"
+ }
+ },
+ {
+ "id": 33,
+ "type": "PreviewImage",
+ "pos": [
+ 1240,
+ 737
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 31,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "images",
+ "type": "IMAGE",
+ "link": 32
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "PreviewImage"
+ }
+ },
+ {
+ "id": 34,
+ "type": "PreviewImage",
+ "pos": [
+ 844,
+ 833
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 32,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "images",
+ "type": "IMAGE",
+ "link": 33
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "PreviewImage"
+ }
+ },
+ {
+ "id": 35,
+ "type": "PreviewImage",
+ "pos": [
+ 1216,
+ 1023
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 33,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "images",
+ "type": "IMAGE",
+ "link": 34
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "PreviewImage"
+ }
+ },
+ {
+ "id": 36,
+ "type": "PreviewImage",
+ "pos": [
+ 838,
+ 1175
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 34,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "images",
+ "type": "IMAGE",
+ "link": 35
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "PreviewImage"
+ }
+ },
+ {
+ "id": 37,
+ "type": "PreviewImage",
+ "pos": [
+ 1282,
+ 1355
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 35,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "images",
+ "type": "IMAGE",
+ "link": 36
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "PreviewImage"
+ }
+ },
+ {
+ "id": 38,
+ "type": "PreviewImage",
+ "pos": [
+ 897,
+ 1532
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 36,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "images",
+ "type": "IMAGE",
+ "link": 37
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "PreviewImage"
+ }
+ },
+ {
+ "id": 39,
+ "type": "PreviewImage",
+ "pos": [
+ 1336,
+ 1704
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 37,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "images",
+ "type": "IMAGE",
+ "link": 38
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "PreviewImage"
+ }
+ },
+ {
+ "id": 40,
+ "type": "PreviewImage",
+ "pos": [
+ 859,
+ 1840
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 38,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "images",
+ "type": "IMAGE",
+ "link": 39
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "PreviewImage"
+ }
+ },
+ {
+ "id": 41,
+ "type": "PreviewImage",
+ "pos": [
+ 1329,
+ 1939
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 39,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "images",
+ "type": "IMAGE",
+ "link": 40
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "PreviewImage"
+ }
+ },
+ {
+ "id": 42,
+ "type": "PreviewImage",
+ "pos": [
+ 888,
+ 2056
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 40,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "images",
+ "type": "IMAGE",
+ "link": 42
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "PreviewImage"
+ }
+ },
+ {
+ "id": 43,
+ "type": "PreviewImage",
+ "pos": [
+ 1278,
+ 2191
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 41,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "images",
+ "type": "IMAGE",
+ "link": 41
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "PreviewImage"
+ }
+ },
+ {
+ "id": 2,
+ "type": "PiDiNetPreprocessor",
+ "pos": [
+ 420,
+ -446
+ ],
+ "size": {
+ "0": 315,
+ "1": 58
+ },
+ "flags": {},
+ "order": 1,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 1
+ }
+ ],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 23
+ ],
+ "shape": 3,
+ "slot_index": 0
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "PiDiNetPreprocessor"
+ },
+ "widgets_values": [
+ "enable"
+ ]
+ },
+ {
+ "id": 3,
+ "type": "ColorPreprocessor",
+ "pos": [
+ 426,
+ -332
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 2,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 2
+ }
+ ],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 24
+ ],
+ "shape": 3,
+ "slot_index": 0
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "ColorPreprocessor"
+ }
+ },
+ {
+ "id": 4,
+ "type": "CannyEdgePreprocessor",
+ "pos": [
+ 433,
+ -245
+ ],
+ "size": {
+ "0": 315,
+ "1": 82
+ },
+ "flags": {},
+ "order": 3,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 3
+ }
+ ],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 25
+ ],
+ "shape": 3,
+ "slot_index": 0
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "CannyEdgePreprocessor"
+ },
+ "widgets_values": [
+ 100,
+ 200
+ ]
+ },
+ {
+ "id": 5,
+ "type": "SAMPreprocessor",
+ "pos": [
+ 427,
+ -108
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 4,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 4
+ }
+ ],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 26
+ ],
+ "shape": 3,
+ "slot_index": 0
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "SAMPreprocessor"
+ }
+ },
+ {
+ "id": 7,
+ "type": "DWPreprocessor",
+ "pos": [
+ 440,
+ 95
+ ],
+ "size": {
+ "0": 315,
+ "1": 106
+ },
+ "flags": {},
+ "order": 5,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 6
+ }
+ ],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 27
+ ],
+ "shape": 3,
+ "slot_index": 0
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "DWPreprocessor"
+ },
+ "widgets_values": [
+ "enable",
+ "enable",
+ "enable"
+ ]
+ },
+ {
+ "id": 8,
+ "type": "BinaryPreprocessor",
+ "pos": [
+ 432,
+ 266
+ ],
+ "size": {
+ "0": 315,
+ "1": 58
+ },
+ "flags": {},
+ "order": 6,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 7
+ }
+ ],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 28
+ ],
+ "shape": 3,
+ "slot_index": 0
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "BinaryPreprocessor"
+ },
+ "widgets_values": [
+ 100
+ ]
+ },
+ {
+ "id": 9,
+ "type": "ScribblePreprocessor",
+ "pos": [
+ 462,
+ 376
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 7,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 8
+ }
+ ],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 29
+ ],
+ "shape": 3,
+ "slot_index": 0
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "ScribblePreprocessor"
+ }
+ },
+ {
+ "id": 10,
+ "type": "M-LSDPreprocessor",
+ "pos": [
+ 453,
+ 497
+ ],
+ "size": {
+ "0": 315,
+ "1": 82
+ },
+ "flags": {},
+ "order": 8,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 9
+ }
+ ],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 30
+ ],
+ "shape": 3,
+ "slot_index": 0
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "M-LSDPreprocessor"
+ },
+ "widgets_values": [
+ 0.1,
+ 0.1
+ ]
+ },
+ {
+ "id": 11,
+ "type": "UniFormer-SemSegPreprocessor",
+ "pos": [
+ 479,
+ 651
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 9,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 10
+ }
+ ],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 31
+ ],
+ "shape": 3,
+ "slot_index": 0
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "UniFormer-SemSegPreprocessor"
+ }
+ },
+ {
+ "id": 12,
+ "type": "Zoe-DepthMapPreprocessor",
+ "pos": [
+ 483,
+ 740
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 10,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 11
+ }
+ ],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 32
+ ],
+ "shape": 3,
+ "slot_index": 0
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "Zoe-DepthMapPreprocessor"
+ }
+ },
+ {
+ "id": 13,
+ "type": "MiDaS-NormalMapPreprocessor",
+ "pos": [
+ 463,
+ 821
+ ],
+ "size": {
+ "0": 315,
+ "1": 82
+ },
+ "flags": {},
+ "order": 11,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 12
+ }
+ ],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 33
+ ],
+ "shape": 3,
+ "slot_index": 0
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "MiDaS-NormalMapPreprocessor"
+ },
+ "widgets_values": [
+ 6.283185307179586,
+ 0.1
+ ]
+ },
+ {
+ "id": 14,
+ "type": "MiDaS-DepthMapPreprocessor",
+ "pos": [
+ 451,
+ 1009
+ ],
+ "size": {
+ "0": 315,
+ "1": 82
+ },
+ "flags": {},
+ "order": 12,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 13
+ }
+ ],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 34
+ ],
+ "shape": 3,
+ "slot_index": 0
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "MiDaS-DepthMapPreprocessor"
+ },
+ "widgets_values": [
+ 6.283185307179586,
+ 0.1
+ ]
+ },
+ {
+ "id": 15,
+ "type": "OpenposePreprocessor",
+ "pos": [
+ 466,
+ 1177
+ ],
+ "size": {
+ "0": 315,
+ "1": 106
+ },
+ "flags": {},
+ "order": 13,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 14
+ }
+ ],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 35
+ ],
+ "shape": 3,
+ "slot_index": 0
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "OpenposePreprocessor"
+ },
+ "widgets_values": [
+ "enable",
+ "enable",
+ "enable"
+ ]
+ },
+ {
+ "id": 17,
+ "type": "LeReS-DepthMapPreprocessor",
+ "pos": [
+ 484,
+ 1533
+ ],
+ "size": {
+ "0": 315,
+ "1": 106
+ },
+ "flags": {},
+ "order": 15,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 16
+ }
+ ],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 37
+ ],
+ "shape": 3,
+ "slot_index": 0
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "LeReS-DepthMapPreprocessor"
+ },
+ "widgets_values": [
+ 0,
+ 0,
+ "enable"
+ ]
+ },
+ {
+ "id": 18,
+ "type": "BAE-NormalMapPreprocessor",
+ "pos": [
+ 510,
+ 1729
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 16,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 17
+ }
+ ],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 38
+ ],
+ "shape": 3,
+ "slot_index": 0
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "BAE-NormalMapPreprocessor"
+ }
+ },
+ {
+ "id": 19,
+ "type": "OneFormer-COCO-SemSegPreprocessor",
+ "pos": [
+ 488,
+ 1843
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 17,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 18
+ }
+ ],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 39
+ ],
+ "shape": 3,
+ "slot_index": 0
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "OneFormer-COCO-SemSegPreprocessor"
+ }
+ },
+ {
+ "id": 20,
+ "type": "OneFormer-ADE20K-SemSegPreprocessor",
+ "pos": [
+ 470,
+ 1941
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 18,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 19
+ }
+ ],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 40
+ ],
+ "shape": 3,
+ "slot_index": 0
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "OneFormer-ADE20K-SemSegPreprocessor"
+ }
+ },
+ {
+ "id": 22,
+ "type": "FakeScribblePreprocessor",
+ "pos": [
+ 426,
+ 2193
+ ],
+ "size": {
+ "0": 315,
+ "1": 58
+ },
+ "flags": {},
+ "order": 20,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 21
+ }
+ ],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 41
+ ],
+ "shape": 3,
+ "slot_index": 0
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "FakeScribblePreprocessor"
+ },
+ "widgets_values": [
+ "enable"
+ ]
+ },
+ {
+ "id": 21,
+ "type": "HEDPreprocessor",
+ "pos": [
+ 460,
+ 2053
+ ],
+ "size": {
+ "0": 315,
+ "1": 58
+ },
+ "flags": {},
+ "order": 19,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 20
+ }
+ ],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 42
+ ],
+ "shape": 3,
+ "slot_index": 0
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "HEDPreprocessor"
+ },
+ "widgets_values": [
+ "enable"
+ ]
+ },
+ {
+ "id": 16,
+ "type": "LineArtPreprocessor",
+ "pos": [
+ 450,
+ 1363
+ ],
+ "size": {
+ "0": 315,
+ "1": 58
+ },
+ "flags": {},
+ "order": 14,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 15
+ }
+ ],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 36
+ ],
+ "shape": 3,
+ "slot_index": 0
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "LineArtPreprocessor"
+ },
+ "widgets_values": [
+ "enable"
+ ]
+ },
+ {
+ "id": 45,
+ "type": "PreviewImage",
+ "pos": [
+ 886,
+ 2316
+ ],
+ "size": {
+ "0": 210,
+ "1": 26
+ },
+ "flags": {},
+ "order": 42,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "images",
+ "type": "IMAGE",
+ "link": 43
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "PreviewImage"
+ }
+ },
+ {
+ "id": 44,
+ "type": "TilePreprocessor",
+ "pos": [
+ 419,
+ 2320
+ ],
+ "size": {
+ "0": 315,
+ "1": 58
+ },
+ "flags": {},
+ "order": 21,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 44
+ }
+ ],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 43
+ ],
+ "shape": 3,
+ "slot_index": 0
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "TilePreprocessor"
+ },
+ "widgets_values": [
+ 3
+ ]
+ },
+ {
+ "id": 1,
+ "type": "LoadImage",
+ "pos": [
+ 19,
+ 298
+ ],
+ "size": {
+ "0": 315,
+ "1": 314
+ },
+ "flags": {},
+ "order": 0,
+ "mode": 0,
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 1,
+ 2,
+ 3,
+ 4,
+ 6,
+ 7,
+ 8,
+ 9,
+ 10,
+ 11,
+ 12,
+ 13,
+ 14,
+ 15,
+ 16,
+ 17,
+ 18,
+ 19,
+ 20,
+ 21,
+ 44
+ ],
+ "shape": 3,
+ "slot_index": 0
+ },
+ {
+ "name": "MASK",
+ "type": "MASK",
+ "links": null,
+ "shape": 3
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "LoadImage"
+ },
+ "widgets_values": [
+ "pose.png",
+ "image"
+ ]
+ }
+ ],
+ "links": [
+ [
+ 1,
+ 1,
+ 0,
+ 2,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 2,
+ 1,
+ 0,
+ 3,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 3,
+ 1,
+ 0,
+ 4,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 4,
+ 1,
+ 0,
+ 5,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 6,
+ 1,
+ 0,
+ 7,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 7,
+ 1,
+ 0,
+ 8,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 8,
+ 1,
+ 0,
+ 9,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 9,
+ 1,
+ 0,
+ 10,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 10,
+ 1,
+ 0,
+ 11,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 11,
+ 1,
+ 0,
+ 12,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 12,
+ 1,
+ 0,
+ 13,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 13,
+ 1,
+ 0,
+ 14,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 14,
+ 1,
+ 0,
+ 15,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 15,
+ 1,
+ 0,
+ 16,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 16,
+ 1,
+ 0,
+ 17,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 17,
+ 1,
+ 0,
+ 18,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 18,
+ 1,
+ 0,
+ 19,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 19,
+ 1,
+ 0,
+ 20,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 20,
+ 1,
+ 0,
+ 21,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 21,
+ 1,
+ 0,
+ 22,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 23,
+ 2,
+ 0,
+ 24,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 24,
+ 3,
+ 0,
+ 25,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 25,
+ 4,
+ 0,
+ 26,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 26,
+ 5,
+ 0,
+ 27,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 27,
+ 7,
+ 0,
+ 28,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 28,
+ 8,
+ 0,
+ 29,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 29,
+ 9,
+ 0,
+ 30,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 30,
+ 10,
+ 0,
+ 31,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 31,
+ 11,
+ 0,
+ 32,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 32,
+ 12,
+ 0,
+ 33,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 33,
+ 13,
+ 0,
+ 34,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 34,
+ 14,
+ 0,
+ 35,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 35,
+ 15,
+ 0,
+ 36,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 36,
+ 16,
+ 0,
+ 37,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 37,
+ 17,
+ 0,
+ 38,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 38,
+ 18,
+ 0,
+ 39,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 39,
+ 19,
+ 0,
+ 40,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 40,
+ 20,
+ 0,
+ 41,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 41,
+ 22,
+ 0,
+ 43,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 42,
+ 21,
+ 0,
+ 42,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 43,
+ 44,
+ 0,
+ 45,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 44,
+ 1,
+ 0,
+ 44,
+ 0,
+ "IMAGE"
+ ]
+ ],
+ "groups": [],
+ "config": {},
+ "extra": {},
+ "version": 0.4
+}
\ No newline at end of file
diff --git a/tests/test_controlnet_aux.py b/tests/test_controlnet_aux.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fdaaca498a58eecd02439e2611e605835d5929f
--- /dev/null
+++ b/tests/test_controlnet_aux.py
@@ -0,0 +1,126 @@
+import os
+import shutil
+from io import BytesIO
+
+import numpy as np
+import pytest
+import requests
+from PIL import Image
+
+from custom_controlnet_aux import (CannyDetector, ContentShuffleDetector, HEDdetector,
+ LeresDetector, LineartAnimeDetector,
+ LineartDetector, MediapipeFaceDetector,
+ MidasDetector, MLSDdetector, NormalBaeDetector,
+ OpenposeDetector, PidiNetDetector, SamDetector,
+ ZoeDetector, TileDetector)
+
+OUTPUT_DIR = "tests/outputs"
+
+def output(name, img):
+ img.save(os.path.join(OUTPUT_DIR, "{:s}.png".format(name)))
+
+def common(name, processor, img):
+ output(name, processor(img))
+ output(name + "_pil_np", Image.fromarray(processor(img, output_type="np")))
+ output(name + "_np_np", Image.fromarray(processor(np.array(img, dtype=np.uint8), output_type="np")))
+ output(name + "_np_pil", processor(np.array(img, dtype=np.uint8), output_type="pil"))
+ output(name + "_scaled", processor(img, detect_resolution=640, image_resolution=768))
+
+def return_pil(name, processor, img):
+ output(name + "_pil_false", Image.fromarray(processor(img, return_pil=False)))
+ output(name + "_pil_true", processor(img, return_pil=True))
+
+@pytest.fixture(scope="module")
+def img():
+ if os.path.exists(OUTPUT_DIR):
+ shutil.rmtree(OUTPUT_DIR)
+ os.mkdir(OUTPUT_DIR)
+ url = "https://huggingface.co/lllyasviel/sd-controlnet-openpose/resolve/main/images/pose.png"
+ response = requests.get(url)
+ img = Image.open(BytesIO(response.content)).convert("RGB").resize((512, 512))
+ return img
+
+def test_canny(img):
+ canny = CannyDetector()
+ common("canny", canny, img)
+ output("canny_img", canny(img=img))
+
+def test_hed(img):
+ hed = HEDdetector.from_pretrained("lllyasviel/Annotators")
+ common("hed", hed, img)
+ return_pil("hed", hed, img)
+ output("hed_safe", hed(img, safe=True))
+ output("hed_scribble", hed(img, scribble=True))
+
+def test_leres(img):
+ leres = LeresDetector.from_pretrained("lllyasviel/Annotators")
+ common("leres", leres, img)
+ output("leres_boost", leres(img, boost=True))
+
+def test_lineart(img):
+ lineart = LineartDetector.from_pretrained("lllyasviel/Annotators")
+ common("lineart", lineart, img)
+ return_pil("lineart", lineart, img)
+ output("lineart_coarse", lineart(img, coarse=True))
+
+def test_lineart_anime(img):
+ lineart_anime = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
+ common("lineart_anime", lineart_anime, img)
+ return_pil("lineart_anime", lineart_anime, img)
+
+def test_mediapipe_face(img):
+ mediapipe = MediapipeFaceDetector()
+ common("mediapipe", mediapipe, img)
+ output("mediapipe_image", mediapipe(image=img))
+
+def test_midas(img):
+ midas = MidasDetector.from_pretrained("lllyasviel/Annotators")
+ common("midas", midas, img)
+ output("midas_normal", midas(img, depth_and_normal=True)[1])
+
+def test_mlsd(img):
+ mlsd = MLSDdetector.from_pretrained("lllyasviel/Annotators")
+ common("mlsd", mlsd, img)
+ return_pil("mlsd", mlsd, img)
+
+def test_normalbae(img):
+ normal_bae = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
+ common("normal_bae", normal_bae, img)
+ return_pil("normal_bae", normal_bae, img)
+
+def test_openpose(img):
+ openpose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
+ common("openpose", openpose, img)
+ return_pil("openpose", openpose, img)
+ output("openpose_hand_and_face_false", openpose(img, hand_and_face=False))
+ output("openpose_hand_and_face_true", openpose(img, hand_and_face=True))
+ output("openpose_face", openpose(img, include_body=True, include_hand=False, include_face=True))
+ output("openpose_faceonly", openpose(img, include_body=False, include_hand=False, include_face=True))
+ output("openpose_full", openpose(img, include_body=True, include_hand=True, include_face=True))
+ output("openpose_hand", openpose(img, include_body=True, include_hand=True, include_face=False))
+
+def test_pidi(img):
+ pidi = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
+ common("pidi", pidi, img)
+ return_pil("pidi", pidi, img)
+ output("pidi_safe", pidi(img, safe=True))
+ output("pidi_scribble", pidi(img, scribble=True))
+
+def test_sam(img):
+ sam = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
+ common("sam", sam, img)
+ output("sam_image", sam(image=img))
+
+def test_shuffle(img):
+ shuffle = ContentShuffleDetector()
+ common("shuffle", shuffle, img)
+ return_pil("shuffle", shuffle, img)
+
+def test_zoe(img):
+ zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
+ common("zoe", zoe, img)
+
+def test_tile(img):
+ tile = TileDetector()
+ common("tile", tile, img)
+ output("tile_img", tile(img))
\ No newline at end of file
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf3dd6e693ba379ef4714745bddeca88d0b633f9
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,250 @@
+import torch
+import numpy as np
+import os
+import cv2
+import yaml
+from pathlib import Path
+from enum import Enum
+from .log import log
+import subprocess
+import threading
+import comfy
+import tempfile
+
+here = Path(__file__).parent.resolve()
+
+config_path = Path(here, "config.yaml")
+
+if os.path.exists(config_path):
+ config = yaml.load(open(config_path, "r"), Loader=yaml.FullLoader)
+
+ annotator_ckpts_path = str(Path(here, config["annotator_ckpts_path"]))
+ TEMP_DIR = config["custom_temp_path"]
+ USE_SYMLINKS = config["USE_SYMLINKS"]
+ ORT_PROVIDERS = config["EP_list"]
+
+ if USE_SYMLINKS is None or type(USE_SYMLINKS) != bool:
+ log.error("USE_SYMLINKS must be a boolean. Using False by default.")
+ USE_SYMLINKS = False
+
+ if TEMP_DIR is None:
+ TEMP_DIR = tempfile.gettempdir()
+ elif not os.path.isdir(TEMP_DIR):
+ try:
+ os.makedirs(TEMP_DIR)
+ except:
+ log.error("Failed to create custom temp directory. Using default.")
+ TEMP_DIR = tempfile.gettempdir()
+
+ if not os.path.isdir(annotator_ckpts_path):
+ try:
+ os.makedirs(annotator_ckpts_path)
+ except:
+ log.error("Failed to create config ckpts directory. Using default.")
+ annotator_ckpts_path = str(Path(here, "./ckpts"))
+else:
+ annotator_ckpts_path = str(Path(here, "./ckpts"))
+ TEMP_DIR = tempfile.gettempdir()
+ USE_SYMLINKS = False
+ ORT_PROVIDERS = ["CUDAExecutionProvider", "DirectMLExecutionProvider", "OpenVINOExecutionProvider", "ROCMExecutionProvider", "CPUExecutionProvider", "CoreMLExecutionProvider"]
+
+os.environ['AUX_ANNOTATOR_CKPTS_PATH'] = os.getenv('AUX_ANNOTATOR_CKPTS_PATH', annotator_ckpts_path)
+os.environ['AUX_TEMP_DIR'] = os.getenv('AUX_TEMP_DIR', str(TEMP_DIR))
+os.environ['AUX_USE_SYMLINKS'] = os.getenv('AUX_USE_SYMLINKS', str(USE_SYMLINKS))
+os.environ['AUX_ORT_PROVIDERS'] = os.getenv('AUX_ORT_PROVIDERS', str(",".join(ORT_PROVIDERS)))
+
+log.info(f"Using ckpts path: {annotator_ckpts_path}")
+log.info(f"Using symlinks: {USE_SYMLINKS}")
+log.info(f"Using ort providers: {ORT_PROVIDERS}")
+
+# Sync with theoritical limit from Comfy base
+# https://github.com/comfyanonymous/ComfyUI/blob/eecd69b53a896343775bcb02a4f8349e7442ffd1/nodes.py#L45
+MAX_RESOLUTION=16384
+
+def common_annotator_call(model, tensor_image, input_batch=False, show_pbar=True, **kwargs):
+ if "detect_resolution" in kwargs:
+ del kwargs["detect_resolution"] #Prevent weird case?
+
+ if "resolution" in kwargs:
+ detect_resolution = kwargs["resolution"] if type(kwargs["resolution"]) == int and kwargs["resolution"] >= 64 else 512
+ del kwargs["resolution"]
+ else:
+ detect_resolution = 512
+
+ if input_batch:
+ np_images = np.asarray(tensor_image * 255., dtype=np.uint8)
+ np_results = model(np_images, output_type="np", detect_resolution=detect_resolution, **kwargs)
+ return torch.from_numpy(np_results.astype(np.float32) / 255.0)
+
+ batch_size = tensor_image.shape[0]
+ if show_pbar:
+ pbar = comfy.utils.ProgressBar(batch_size)
+ out_tensor = None
+ for i, image in enumerate(tensor_image):
+ np_image = np.asarray(image.cpu() * 255., dtype=np.uint8)
+ np_result = model(np_image, output_type="np", detect_resolution=detect_resolution, **kwargs)
+ out = torch.from_numpy(np_result.astype(np.float32) / 255.0)
+ if out_tensor is None:
+ out_tensor = torch.zeros(batch_size, *out.shape, dtype=torch.float32)
+ out_tensor[i] = out
+ if show_pbar:
+ pbar.update(1)
+ return out_tensor
+
+def define_preprocessor_inputs(**arguments):
+ return dict(
+ required=dict(image=INPUT.IMAGE()),
+ optional=arguments
+ )
+
+class INPUT(Enum):
+ def IMAGE():
+ return ("IMAGE",)
+ def LATENT():
+ return ("LATENT",)
+ def MASK():
+ return ("MASK",)
+ def SEED(default=0):
+ return ("INT", dict(default=default, min=0, max=0xffffffffffffffff))
+ def RESOLUTION(default=512, min=64, max=MAX_RESOLUTION, step=64):
+ return ("INT", dict(default=default, min=min, max=max, step=step))
+ def INT(default=0, min=0, max=MAX_RESOLUTION, step=1):
+ return ("INT", dict(default=default, min=min, max=max, step=step))
+ def FLOAT(default=0, min=0, max=1, step=0.01):
+ return ("FLOAT", dict(default=default, min=min, max=max, step=step))
+ def STRING(default='', multiline=False):
+ return ("STRING", dict(default=default, multiline=multiline))
+ def COMBO(values, default=None):
+ return (values, dict(default=values[0] if default is None else default))
+ def BOOLEAN(default=True):
+ return ("BOOLEAN", dict(default=default))
+
+
+
+class ResizeMode(Enum):
+ """
+ Resize modes for ControlNet input images.
+ """
+
+ RESIZE = "Just Resize"
+ INNER_FIT = "Crop and Resize"
+ OUTER_FIT = "Resize and Fill"
+
+ def int_value(self):
+ if self == ResizeMode.RESIZE:
+ return 0
+ elif self == ResizeMode.INNER_FIT:
+ return 1
+ elif self == ResizeMode.OUTER_FIT:
+ return 2
+ assert False, "NOTREACHED"
+
+#https://github.com/Mikubill/sd-webui-controlnet/blob/e67e017731aad05796b9615dc6eadce911298ea1/internal_controlnet/external_code.py#L89
+#Replaced logger with internal log
+def pixel_perfect_resolution(
+ image: np.ndarray,
+ target_H: int,
+ target_W: int,
+ resize_mode: ResizeMode,
+) -> int:
+ """
+ Calculate the estimated resolution for resizing an image while preserving aspect ratio.
+
+ The function first calculates scaling factors for height and width of the image based on the target
+ height and width. Then, based on the chosen resize mode, it either takes the smaller or the larger
+ scaling factor to estimate the new resolution.
+
+ If the resize mode is OUTER_FIT, the function uses the smaller scaling factor, ensuring the whole image
+ fits within the target dimensions, potentially leaving some empty space.
+
+ If the resize mode is not OUTER_FIT, the function uses the larger scaling factor, ensuring the target
+ dimensions are fully filled, potentially cropping the image.
+
+ After calculating the estimated resolution, the function prints some debugging information.
+
+ Args:
+ image (np.ndarray): A 3D numpy array representing an image. The dimensions represent [height, width, channels].
+ target_H (int): The target height for the image.
+ target_W (int): The target width for the image.
+ resize_mode (ResizeMode): The mode for resizing.
+
+ Returns:
+ int: The estimated resolution after resizing.
+ """
+ raw_H, raw_W, _ = image.shape
+
+ k0 = float(target_H) / float(raw_H)
+ k1 = float(target_W) / float(raw_W)
+
+ if resize_mode == ResizeMode.OUTER_FIT:
+ estimation = min(k0, k1) * float(min(raw_H, raw_W))
+ else:
+ estimation = max(k0, k1) * float(min(raw_H, raw_W))
+
+ log.debug(f"Pixel Perfect Computation:")
+ log.debug(f"resize_mode = {resize_mode}")
+ log.debug(f"raw_H = {raw_H}")
+ log.debug(f"raw_W = {raw_W}")
+ log.debug(f"target_H = {target_H}")
+ log.debug(f"target_W = {target_W}")
+ log.debug(f"estimation = {estimation}")
+
+ return int(np.round(estimation))
+
+#https://github.com/Mikubill/sd-webui-controlnet/blob/e67e017731aad05796b9615dc6eadce911298ea1/scripts/controlnet.py#L404
+def safe_numpy(x):
+ # A very safe method to make sure that Apple/Mac works
+ y = x
+
+ # below is very boring but do not change these. If you change these Apple or Mac may fail.
+ y = y.copy()
+ y = np.ascontiguousarray(y)
+ y = y.copy()
+ return y
+
+#https://github.com/Mikubill/sd-webui-controlnet/blob/e67e017731aad05796b9615dc6eadce911298ea1/scripts/utils.py#L140
+def get_unique_axis0(data):
+ arr = np.asanyarray(data)
+ idxs = np.lexsort(arr.T)
+ arr = arr[idxs]
+ unique_idxs = np.empty(len(arr), dtype=np.bool_)
+ unique_idxs[:1] = True
+ unique_idxs[1:] = np.any(arr[:-1, :] != arr[1:, :], axis=-1)
+ return arr[unique_idxs]
+
+#Ref: https://github.com/ltdrdata/ComfyUI-Manager/blob/284e90dc8296a2e1e4f14b4b2d10fba2f52f0e53/__init__.py#L14
+def handle_stream(stream, prefix):
+ for line in stream:
+ print(prefix, line, end="")
+
+
+def run_script(cmd, cwd='.'):
+ process = subprocess.Popen(cmd, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1)
+
+ stdout_thread = threading.Thread(target=handle_stream, args=(process.stdout, ""))
+ stderr_thread = threading.Thread(target=handle_stream, args=(process.stderr, "[!]"))
+
+ stdout_thread.start()
+ stderr_thread.start()
+
+ stdout_thread.join()
+ stderr_thread.join()
+
+ return process.wait()
+
+def nms(x, t, s):
+ x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
+
+ f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
+ f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
+ f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
+ f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
+
+ y = np.zeros_like(x)
+
+ for f in [f1, f2, f3, f4]:
+ np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
+
+ z = np.zeros_like(y, dtype=np.uint8)
+ z[y > t] = 255
+ return z