Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +23 -0
- .gitignore +14 -0
- LICENSE +107 -0
- README.md +191 -8
- app.py +391 -0
- app_flux.py +305 -0
- app_p2p.py +567 -0
- densepose/__init__.py +22 -0
- densepose/config.py +277 -0
- densepose/converters/__init__.py +17 -0
- densepose/converters/base.py +95 -0
- densepose/converters/builtin.py +33 -0
- densepose/converters/chart_output_hflip.py +73 -0
- densepose/converters/chart_output_to_chart_result.py +190 -0
- densepose/converters/hflip.py +36 -0
- densepose/converters/segm_to_mask.py +152 -0
- densepose/converters/to_chart_result.py +72 -0
- densepose/converters/to_mask.py +51 -0
- densepose/data/__init__.py +27 -0
- densepose/data/build.py +738 -0
- densepose/data/combined_loader.py +46 -0
- densepose/data/dataset_mapper.py +170 -0
- densepose/data/image_list_dataset.py +74 -0
- densepose/data/inference_based_loader.py +174 -0
- densepose/data/meshes/__init__.py +7 -0
- densepose/data/meshes/builtin.py +103 -0
- densepose/data/meshes/catalog.py +73 -0
- densepose/data/samplers/__init__.py +10 -0
- densepose/data/samplers/densepose_base.py +205 -0
- densepose/data/samplers/densepose_confidence_based.py +110 -0
- densepose/data/samplers/densepose_cse_base.py +141 -0
- densepose/data/samplers/densepose_cse_confidence_based.py +121 -0
- densepose/data/samplers/densepose_cse_uniform.py +14 -0
- densepose/data/samplers/densepose_uniform.py +43 -0
- densepose/data/samplers/mask_from_densepose.py +30 -0
- densepose/data/samplers/prediction_to_gt.py +100 -0
- densepose/data/transform/__init__.py +5 -0
- densepose/data/transform/image.py +41 -0
- densepose/data/utils.py +40 -0
- densepose/data/video/__init__.py +19 -0
- densepose/data/video/frame_selector.py +89 -0
- densepose/data/video/video_keyframe_dataset.py +304 -0
- densepose/engine/__init__.py +5 -0
- densepose/engine/trainer.py +260 -0
- densepose/evaluation/__init__.py +5 -0
- densepose/evaluation/d2_evaluator_adapter.py +52 -0
- densepose/evaluation/densepose_coco_evaluation.py +1305 -0
- densepose/evaluation/evaluator.py +423 -0
- densepose/evaluation/mesh_alignment_evaluator.py +68 -0
- densepose/evaluation/tensor_storage.py +241 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,26 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
detectron2/_C.cpython-39-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
resource/demo/example/condition/overall/21744571_51588794_1000.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
resource/demo/example/condition/overall/23962182_54027982_1000.jpg filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
resource/demo/example/condition/person/baumu30483223c3_1719437121402_2-0._QL90_UX564_V12524t6_.jpg filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
resource/demo/example/condition/person/mison407622250d_1719258948458_2-0._QL90_UX564_V12524t6_.jpg filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
resource/demo/example/condition/person/mothr22044226e8_1718142523286_2-0._QL90_UX564_V12524t6_.jpg filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
resource/demo/example/condition/upper/21514384_52353349_1000.jpg filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
resource/demo/example/condition/upper/22790049_53294275_1000.jpg filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
resource/demo/example/condition/upper/24083449_54173465_2048.jpg filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
resource/demo/example/person/men/Simon_1.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
resource/demo/example/person/men/Yifeng_0.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
resource/demo/example/person/men/model_5.png filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
resource/demo/example/person/men/model_7.png filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
resource/demo/example/person/women/049713_0.jpg filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
resource/demo/example/person/women/1-model_3.png filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
resource/demo/example/person/women/2-model_4.png filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
resource/demo/example/person/women/model_8.png filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
resource/img/architecture.jpg filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
resource/img/comfyui-1.png filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
resource/img/comfyui.png filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
resource/img/efficency.jpg filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
resource/img/structure.jpg filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
resource/img/teaser.jpg filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
+
model/__pycache__
|
| 3 |
+
model/DensePose/__pycache__
|
| 4 |
+
model/SCHP/__pycache__
|
| 5 |
+
model/SCHP/*/__pycache__
|
| 6 |
+
resource/demo/output
|
| 7 |
+
resource/demo/example/.DS_Store
|
| 8 |
+
Models
|
| 9 |
+
Datasets
|
| 10 |
+
densepose_
|
| 11 |
+
.vscode
|
| 12 |
+
playground.py
|
| 13 |
+
output
|
| 14 |
+
model/cloth_masker_segformer.py
|
LICENSE
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
|
| 2 |
+
|
| 3 |
+
Creative Commons Corporation ("Creative Commons") is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an "as-is" basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.
|
| 4 |
+
|
| 5 |
+
Using Creative Commons Public Licenses
|
| 6 |
+
|
| 7 |
+
Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
|
| 8 |
+
|
| 9 |
+
Considerations for licensors: Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. More considerations for licensors : wiki.creativecommons.org/Considerations_for_licensors
|
| 10 |
+
|
| 11 |
+
Considerations for the public: By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor's permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. More considerations for the public : wiki.creativecommons.org/Considerations_for_licensees
|
| 12 |
+
|
| 13 |
+
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License
|
| 14 |
+
|
| 15 |
+
By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
|
| 16 |
+
|
| 17 |
+
Section 1 – Definitions.
|
| 18 |
+
|
| 19 |
+
a. Adapted Material means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
|
| 20 |
+
b. Adapter's License means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
|
| 21 |
+
c. BY-NC-SA Compatible License means a license listed at creativecommons.org/compatiblelicenses, approved by Creative Commons as essentially the equivalent of this Public License.
|
| 22 |
+
d. Copyright and Similar Rights means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
|
| 23 |
+
e. Effective Technological Measures means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
|
| 24 |
+
f. Exceptions and Limitations means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
|
| 25 |
+
g. License Elements means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution, NonCommercial, and ShareAlike.
|
| 26 |
+
h. Licensed Material means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
|
| 27 |
+
i. Licensed Rights means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
|
| 28 |
+
j. Licensor means the individual(s) or entity(ies) granting rights under this Public License.
|
| 29 |
+
k. NonCommercial means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
|
| 30 |
+
l. Share means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
|
| 31 |
+
m. Sui Generis Database Rights means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
|
| 32 |
+
n. You means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
|
| 33 |
+
Section 2 – Scope.
|
| 34 |
+
|
| 35 |
+
a. License grant.
|
| 36 |
+
1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
|
| 37 |
+
A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
|
| 38 |
+
B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
|
| 39 |
+
2. Exceptions and Limitations. For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
|
| 40 |
+
3. Term. The term of this Public License is specified in Section 6(a).
|
| 41 |
+
4. Media and formats; technical modifications allowed. The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
|
| 42 |
+
5. Downstream recipients.
|
| 43 |
+
A. Offer from the Licensor – Licensed Material. Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
|
| 44 |
+
B. Additional offer from the Licensor – Adapted Material. Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter's License You apply.
|
| 45 |
+
C. No downstream restrictions. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
|
| 46 |
+
6. No endorsement. Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
|
| 47 |
+
b. Other rights.
|
| 48 |
+
1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
|
| 49 |
+
2. Patent and trademark rights are not licensed under this Public License.
|
| 50 |
+
3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
|
| 51 |
+
Section 3 – License Conditions.
|
| 52 |
+
|
| 53 |
+
Your exercise of the Licensed Rights is expressly made subject to the following conditions.
|
| 54 |
+
|
| 55 |
+
a. Attribution.
|
| 56 |
+
1. If You Share the Licensed Material (including in modified form), You must:
|
| 57 |
+
A. retain the following if it is supplied by the Licensor with the Licensed Material:
|
| 58 |
+
i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
|
| 59 |
+
ii. a copyright notice;
|
| 60 |
+
iii. a notice that refers to this Public License;
|
| 61 |
+
iv. a notice that refers to the disclaimer of warranties;
|
| 62 |
+
v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
|
| 63 |
+
|
| 64 |
+
B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
|
| 65 |
+
C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
|
| 66 |
+
2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
|
| 67 |
+
3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
|
| 68 |
+
b. ShareAlike.In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply.
|
| 69 |
+
1. The Adapter's License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-NC-SA Compatible License.
|
| 70 |
+
2. You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material.
|
| 71 |
+
3. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply.
|
| 72 |
+
Section 4 – Sui Generis Database Rights.
|
| 73 |
+
|
| 74 |
+
Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
|
| 75 |
+
|
| 76 |
+
a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
|
| 77 |
+
b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and
|
| 78 |
+
c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
|
| 79 |
+
For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
|
| 80 |
+
Section 5 – Disclaimer of Warranties and Limitation of Liability.
|
| 81 |
+
|
| 82 |
+
a. Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.
|
| 83 |
+
b. To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.
|
| 84 |
+
c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
|
| 85 |
+
Section 6 – Term and Termination.
|
| 86 |
+
|
| 87 |
+
a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
|
| 88 |
+
b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
|
| 89 |
+
1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
|
| 90 |
+
2. upon express reinstatement by the Licensor.
|
| 91 |
+
For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
|
| 92 |
+
|
| 93 |
+
c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
|
| 94 |
+
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
|
| 95 |
+
Section 7 – Other Terms and Conditions.
|
| 96 |
+
|
| 97 |
+
a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
|
| 98 |
+
b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
|
| 99 |
+
Section 8 – Interpretation.
|
| 100 |
+
|
| 101 |
+
a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
|
| 102 |
+
b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
|
| 103 |
+
c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
|
| 104 |
+
d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
|
| 105 |
+
Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the "Licensor." The text of the Creative Commons public licenses is dedicated to the public domain under the CC0 Public Domain Dedication. Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at creativecommons.org/policies, Creative Commons does not authorize the use of the trademark "Creative Commons" or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
|
| 106 |
+
|
| 107 |
+
Creative Commons may be contacted at creativecommons.org.
|
README.md
CHANGED
|
@@ -1,12 +1,195 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji: 🐨
|
| 4 |
-
colorFrom: purple
|
| 5 |
-
colorTo: blue
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 5.35.0
|
| 8 |
app_file: app.py
|
| 9 |
-
|
|
|
|
| 10 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: vtontry
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
app_file: app.py
|
| 4 |
+
sdk: gradio
|
| 5 |
+
sdk_version: 4.41.0
|
| 6 |
---
|
| 7 |
+
# [ICLR 25]🐈 CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models
|
| 8 |
+
|
| 9 |
+
<div style="display: flex; justify-content: center; align-items: center;">
|
| 10 |
+
<a href="http://arxiv.org/abs/2407.15886" style="margin: 0 2px;">
|
| 11 |
+
<img src='https://img.shields.io/badge/arXiv-2407.15886-red?style=flat&logo=arXiv&logoColor=red' alt='arxiv'>
|
| 12 |
+
</a>
|
| 13 |
+
<a href='https://huggingface.co/zhengchong/CatVTON' style="margin: 0 2px;">
|
| 14 |
+
<img src='https://img.shields.io/badge/Hugging Face-ckpts-orange?style=flat&logo=HuggingFace&logoColor=orange' alt='huggingface'>
|
| 15 |
+
</a>
|
| 16 |
+
<a href="https://github.com/Zheng-Chong/CatVTON" style="margin: 0 2px;">
|
| 17 |
+
<img src='https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub' alt='GitHub'>
|
| 18 |
+
</a>
|
| 19 |
+
<a href="https://huggingface.co/spaces/zhengchong/CatVTON" style="margin: 0 2px;">
|
| 20 |
+
<img src='https://img.shields.io/badge/Demo-Gradio-gold?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
|
| 21 |
+
</a>
|
| 22 |
+
<a href="https://huggingface.co/spaces/zhengchong/CatVTON" style="margin: 0 2px;">
|
| 23 |
+
<img src='https://img.shields.io/badge/Space-ZeroGPU-orange?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
|
| 24 |
+
</a>
|
| 25 |
+
<a href='https://zheng-chong.github.io/CatVTON/' style="margin: 0 2px;">
|
| 26 |
+
<img src='https://img.shields.io/badge/Webpage-Project-silver?style=flat&logo=&logoColor=orange' alt='webpage'>
|
| 27 |
+
</a>
|
| 28 |
+
<a href="https://github.com/Zheng-Chong/CatVTON/LICENCE" style="margin: 0 2px;">
|
| 29 |
+
<img src='https://img.shields.io/badge/License-CC BY--NC--SA--4.0-lightgreen?style=flat&logo=Lisence' alt='License'>
|
| 30 |
+
</a>
|
| 31 |
+
</div>
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
**CatVTON** is a simple and efficient virtual try-on diffusion model with ***1) Lightweight Network (899.06M parameters totally)***, ***2) Parameter-Efficient Training (49.57M parameters trainable)*** and ***3) Simplified Inference (< 8G VRAM for 1024X768 resolution)***.
|
| 35 |
+
<div align="center">
|
| 36 |
+
<img src="resource/img/teaser.jpg" width="100%" height="100%"/>
|
| 37 |
+
</div>
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
## Updates
|
| 42 |
+
- **`2025/02/24`**: 🎉 We are excited to announce [**CatV2TON**](https://github.com/Zheng-Chong/CatV2TON), our new DiT-based model that supports both **image and video try-on**! Check it out!
|
| 43 |
+
- **`2025/02/20`**: Our [**Paper on ArXiv**](http://arxiv.org/abs/2407.15886) has been updated to v2, which includes more details.
|
| 44 |
+
- **`2025/01/24`**: 🥳 CatVTON has been accepted to **ICLR 2025**!
|
| 45 |
+
- **`2024/12/20`**: 😄 Code for gradio app of **CatVTON-FLUX** has been released! It is not a stable version, but it is a good start!
|
| 46 |
+
- **`2024/12/19`**: [**CatVTON-FLUX**](https://huggingface.co/spaces/zhengchong/CatVTON) has been released! It is a extremely lightweight LoRA (only 37.4M checkpints) for [FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev), the lora weights are available in **[huggingface repo](https://huggingface.co/zhengchong/CatVTON/tree/main/flux-lora)**, code will be released soon!
|
| 47 |
+
- **`2024/11/26`**: Our **unified vision-based model for image and video try-on** will be released soon, bringing a brand-new virtual try-on experience! While our demo page will be temporarily taken offline, [**the demo on HuggingFace Space**](https://huggingface.co/spaces/zhengchong/CatVTON) will remain available for use !
|
| 48 |
+
- **`2024/10/17`**:[**Mask-free version**](https://huggingface.co/zhengchong/CatVTON-MaskFree)🤗 of CatVTON is release !
|
| 49 |
+
- **`2024/10/13`**: We have built a repo [**Awesome-Try-On-Models**](https://github.com/Zheng-Chong/Awesome-Try-On-Models) that focuses on image, video, and 3D-based try-on models published after 2023, aiming to provide insights into the latest technological trends. If you're interested, feel free to contribute or give it a 🌟 star!
|
| 50 |
+
- **`2024/08/13`**: We localize DensePose & SCHP to avoid certain environment issues.
|
| 51 |
+
- **`2024/08/10`**: Our 🤗 [**HuggingFace Space**](https://huggingface.co/spaces/zhengchong/CatVTON) is available now! Thanks for the grant from [**ZeroGPU**](https://huggingface.co/zero-gpu-explorers)!
|
| 52 |
+
- **`2024/08/09`**: [**Evaluation code**](https://github.com/Zheng-Chong/CatVTON?tab=readme-ov-file#3-calculate-metrics) is provided to calculate metrics 📚.
|
| 53 |
+
- **`2024/07/27`**: We provide code and workflow for deploying CatVTON on [**ComfyUI**](https://github.com/Zheng-Chong/CatVTON?tab=readme-ov-file#comfyui-workflow) 💥.
|
| 54 |
+
- **`2024/07/24`**: Our [**Paper on ArXiv**](http://arxiv.org/abs/2407.15886) is available 🥳!
|
| 55 |
+
- **`2024/07/22`**: Our [**App Code**](https://github.com/Zheng-Chong/CatVTON/blob/main/app.py) is released, deploy and enjoy CatVTON on your mechine 🎉!
|
| 56 |
+
- **`2024/07/21`**: Our [**Inference Code**](https://github.com/Zheng-Chong/CatVTON/blob/main/inference.py) and [**Weights** 🤗](https://huggingface.co/zhengchong/CatVTON) are released.
|
| 57 |
+
- **`2024/07/11`**: Our [**Online Demo**](https://huggingface.co/spaces/zhengchong/CatVTON) is released 😁.
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
## Installation
|
| 63 |
+
|
| 64 |
+
Create a conda environment & Install requirments
|
| 65 |
+
```shell
|
| 66 |
+
conda create -n catvton python==3.9.0
|
| 67 |
+
conda activate catvton
|
| 68 |
+
cd CatVTON-main # or your path to CatVTON project dir
|
| 69 |
+
pip install -r requirements.txt
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
## Deployment
|
| 73 |
+
### ComfyUI Workflow
|
| 74 |
+
We have modified the main code to enable easy deployment of CatVTON on [ComfyUI](https://github.com/comfyanonymous/ComfyUI). Due to the incompatibility of the code structure, we have released this part in the [Releases](https://github.com/Zheng-Chong/CatVTON/releases/tag/ComfyUI), which includes the code placed under `custom_nodes` of ComfyUI and our workflow JSON files.
|
| 75 |
+
|
| 76 |
+
To deploy CatVTON to your ComfyUI, follow these steps:
|
| 77 |
+
1. Install all the requirements for both CatVTON and ComfyUI, refer to [Installation Guide for CatVTON](https://github.com/Zheng-Chong/CatVTON/blob/main/INSTALL.md) and [Installation Guide for ComfyUI](https://github.com/comfyanonymous/ComfyUI?tab=readme-ov-file#installing).
|
| 78 |
+
2. Download [`ComfyUI-CatVTON.zip`](https://github.com/Zheng-Chong/CatVTON/releases/download/ComfyUI/ComfyUI-CatVTON.zip) and unzip it in the `custom_nodes` folder under your ComfyUI project (clone from [ComfyUI](https://github.com/comfyanonymous/ComfyUI)).
|
| 79 |
+
3. Run the ComfyUI.
|
| 80 |
+
4. Download [`catvton_workflow.json`](https://github.com/Zheng-Chong/CatVTON/releases/download/ComfyUI/catvton_workflow.json) and drag it into you ComfyUI webpage and enjoy 😆!
|
| 81 |
+
|
| 82 |
+
> Problems under Windows OS, please refer to [issue#8](https://github.com/Zheng-Chong/CatVTON/issues/8).
|
| 83 |
+
>
|
| 84 |
+
When you run the CatVTON workflow for the first time, the weight files will be automatically downloaded, usually taking dozens of minutes.
|
| 85 |
+
|
| 86 |
+
<div align="center">
|
| 87 |
+
<img src="resource/img/comfyui-1.png" width="100%" height="100%"/>
|
| 88 |
+
</div>
|
| 89 |
+
|
| 90 |
+
<!-- <div align="center">
|
| 91 |
+
<img src="resource/img/comfyui.png" width="100%" height="100%"/>
|
| 92 |
+
</div> -->
|
| 93 |
+
|
| 94 |
+
### Gradio App
|
| 95 |
+
|
| 96 |
+
To deploy the Gradio App for CatVTON on your machine, run the following command, and checkpoints will be automatically downloaded from HuggingFace.
|
| 97 |
+
|
| 98 |
+
```PowerShell
|
| 99 |
+
CUDA_VISIBLE_DEVICES=0 python app.py \
|
| 100 |
+
--output_dir="resource/demo/output" \
|
| 101 |
+
--mixed_precision="bf16" \
|
| 102 |
+
--allow_tf32
|
| 103 |
+
```
|
| 104 |
+
When using `bf16` precision, generating results with a resolution of `1024x768` only requires about `8G` VRAM.
|
| 105 |
+
|
| 106 |
+
## Inference
|
| 107 |
+
### 1. Data Preparation
|
| 108 |
+
Before inference, you need to download the [VITON-HD](https://github.com/shadow2496/VITON-HD) or [DressCode](https://github.com/aimagelab/dress-code) dataset.
|
| 109 |
+
Once the datasets are downloaded, the folder structures should look like these:
|
| 110 |
+
```
|
| 111 |
+
├── VITON-HD
|
| 112 |
+
| ├── test_pairs_unpaired.txt
|
| 113 |
+
│ ├── test
|
| 114 |
+
| | ├── image
|
| 115 |
+
│ │ │ ├── [000006_00.jpg | 000008_00.jpg | ...]
|
| 116 |
+
│ │ ├── cloth
|
| 117 |
+
│ │ │ ├── [000006_00.jpg | 000008_00.jpg | ...]
|
| 118 |
+
│ │ ├── agnostic-mask
|
| 119 |
+
│ │ │ ├── [000006_00_mask.png | 000008_00.png | ...]
|
| 120 |
+
...
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
```
|
| 124 |
+
├── DressCode
|
| 125 |
+
| ├── test_pairs_paired.txt
|
| 126 |
+
| ├── test_pairs_unpaired.txt
|
| 127 |
+
│ ├── [dresses | lower_body | upper_body]
|
| 128 |
+
| | ├── test_pairs_paired.txt
|
| 129 |
+
| | ├── test_pairs_unpaired.txt
|
| 130 |
+
│ │ ├── images
|
| 131 |
+
│ │ │ ├── [013563_0.jpg | 013563_1.jpg | 013564_0.jpg | 013564_1.jpg | ...]
|
| 132 |
+
│ │ ├── agnostic_masks
|
| 133 |
+
│ │ │ ├── [013563_0.png| 013564_0.png | ...]
|
| 134 |
+
...
|
| 135 |
+
```
|
| 136 |
+
For the DressCode dataset, we provide script to preprocessed agnostic masks, run the following command:
|
| 137 |
+
```PowerShell
|
| 138 |
+
CUDA_VISIBLE_DEVICES=0 python preprocess_agnostic_mask.py \
|
| 139 |
+
--data_root_path <your_path_to_DressCode>
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
### 2. Inference on VTIONHD/DressCode
|
| 143 |
+
To run the inference on the DressCode or VITON-HD dataset, run the following command, checkpoints will be automatically downloaded from HuggingFace.
|
| 144 |
+
|
| 145 |
+
```PowerShell
|
| 146 |
+
CUDA_VISIBLE_DEVICES=0 python inference.py \
|
| 147 |
+
--dataset [dresscode | vitonhd] \
|
| 148 |
+
--data_root_path <path> \
|
| 149 |
+
--output_dir <path>
|
| 150 |
+
--dataloader_num_workers 8 \
|
| 151 |
+
--batch_size 8 \
|
| 152 |
+
--seed 555 \
|
| 153 |
+
--mixed_precision [no | fp16 | bf16] \
|
| 154 |
+
--allow_tf32 \
|
| 155 |
+
--repaint \
|
| 156 |
+
--eval_pair
|
| 157 |
+
```
|
| 158 |
+
### 3. Calculate Metrics
|
| 159 |
+
|
| 160 |
+
After obtaining the inference results, calculate the metrics using the following command:
|
| 161 |
+
|
| 162 |
+
```PowerShell
|
| 163 |
+
CUDA_VISIBLE_DEVICES=0 python eval.py \
|
| 164 |
+
--gt_folder <your_path_to_gt_image_folder> \
|
| 165 |
+
--pred_folder <your_path_to_predicted_image_folder> \
|
| 166 |
+
--paired \
|
| 167 |
+
--batch_size=16 \
|
| 168 |
+
--num_workers=16
|
| 169 |
+
```
|
| 170 |
+
|
| 171 |
+
- `--gt_folder` and `--pred_folder` should be folders that contain **only images**.
|
| 172 |
+
- To evaluate the results in a paired setting, use `--paired`; for an unpaired setting, simply omit it.
|
| 173 |
+
- `--batch_size` and `--num_workers` should be adjusted based on your machine.
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
## Acknowledgement
|
| 177 |
+
Our code is modified based on [Diffusers](https://github.com/huggingface/diffusers). We adopt [Stable Diffusion v1.5 inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting) as the base model. We use [SCHP](https://github.com/GoGoDuck912/Self-Correction-Human-Parsing/tree/master) and [DensePose](https://github.com/facebookresearch/DensePose) to automatically generate masks in our [Gradio](https://github.com/gradio-app/gradio) App and [ComfyUI](https://github.com/comfyanonymous/ComfyUI) workflow. Thanks to all the contributors!
|
| 178 |
+
|
| 179 |
+
## License
|
| 180 |
+
All the materials, including code, checkpoints, and demo, are made available under the [Creative Commons BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/) license. You are free to copy, redistribute, remix, transform, and build upon the project for non-commercial purposes, as long as you give appropriate credit and distribute your contributions under the same license.
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
## Citation
|
| 184 |
|
| 185 |
+
```bibtex
|
| 186 |
+
@misc{chong2024catvtonconcatenationneedvirtual,
|
| 187 |
+
title={CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models},
|
| 188 |
+
author={Zheng Chong and Xiao Dong and Haoxiang Li and Shiyue Zhang and Wenqing Zhang and Xujie Zhang and Hanqing Zhao and Xiaodan Liang},
|
| 189 |
+
year={2024},
|
| 190 |
+
eprint={2407.15886},
|
| 191 |
+
archivePrefix={arXiv},
|
| 192 |
+
primaryClass={cs.CV},
|
| 193 |
+
url={https://arxiv.org/abs/2407.15886},
|
| 194 |
+
}
|
| 195 |
+
```
|
app.py
ADDED
|
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
|
| 5 |
+
import gradio as gr
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 9 |
+
from huggingface_hub import snapshot_download
|
| 10 |
+
from PIL import Image
|
| 11 |
+
|
| 12 |
+
# Set memory growth for MPS to prevent out of memory errors
|
| 13 |
+
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'
|
| 14 |
+
|
| 15 |
+
from model.cloth_masker import AutoMasker, vis_mask
|
| 16 |
+
from model.pipeline import CatVTONPipeline
|
| 17 |
+
from utils import init_weight_dtype, resize_and_crop, resize_and_padding
|
| 18 |
+
|
| 19 |
+
def parse_args():
|
| 20 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
| 21 |
+
parser.add_argument(
|
| 22 |
+
"--base_model_path",
|
| 23 |
+
type=str,
|
| 24 |
+
default="booksforcharlie/stable-diffusion-inpainting", # Change to a copy repo as runawayml delete original repo
|
| 25 |
+
help=(
|
| 26 |
+
"The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
|
| 27 |
+
),
|
| 28 |
+
)
|
| 29 |
+
parser.add_argument(
|
| 30 |
+
"--resume_path",
|
| 31 |
+
type=str,
|
| 32 |
+
default="zhengchong/CatVTON",
|
| 33 |
+
help=(
|
| 34 |
+
"The Path to the checkpoint of trained tryon model."
|
| 35 |
+
),
|
| 36 |
+
)
|
| 37 |
+
parser.add_argument(
|
| 38 |
+
"--output_dir",
|
| 39 |
+
type=str,
|
| 40 |
+
default="resource/demo/output",
|
| 41 |
+
help="The output directory where the model predictions will be written.",
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
parser.add_argument(
|
| 45 |
+
"--width",
|
| 46 |
+
type=int,
|
| 47 |
+
default=768,
|
| 48 |
+
help=(
|
| 49 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
| 50 |
+
" resolution"
|
| 51 |
+
),
|
| 52 |
+
)
|
| 53 |
+
parser.add_argument(
|
| 54 |
+
"--height",
|
| 55 |
+
type=int,
|
| 56 |
+
default=1024,
|
| 57 |
+
help=(
|
| 58 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
| 59 |
+
" resolution"
|
| 60 |
+
),
|
| 61 |
+
)
|
| 62 |
+
parser.add_argument(
|
| 63 |
+
"--repaint",
|
| 64 |
+
action="store_true",
|
| 65 |
+
help="Whether to repaint the result image with the original background."
|
| 66 |
+
)
|
| 67 |
+
parser.add_argument(
|
| 68 |
+
"--allow_tf32",
|
| 69 |
+
action="store_true",
|
| 70 |
+
default=True,
|
| 71 |
+
help=(
|
| 72 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
| 73 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
| 74 |
+
),
|
| 75 |
+
)
|
| 76 |
+
parser.add_argument(
|
| 77 |
+
"--mixed_precision",
|
| 78 |
+
type=str,
|
| 79 |
+
default="no",
|
| 80 |
+
choices=["no", "fp16", "bf16"],
|
| 81 |
+
help=(
|
| 82 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
| 83 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
| 84 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
| 85 |
+
),
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
args = parser.parse_args()
|
| 89 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 90 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
| 91 |
+
args.local_rank = env_local_rank
|
| 92 |
+
|
| 93 |
+
return args
|
| 94 |
+
|
| 95 |
+
def image_grid(imgs, rows, cols):
|
| 96 |
+
assert len(imgs) == rows * cols
|
| 97 |
+
|
| 98 |
+
w, h = imgs[0].size
|
| 99 |
+
grid = Image.new("RGB", size=(cols * w, rows * h))
|
| 100 |
+
|
| 101 |
+
for i, img in enumerate(imgs):
|
| 102 |
+
grid.paste(img, box=(i % cols * w, i // cols * h))
|
| 103 |
+
return grid
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
args = parse_args()
|
| 107 |
+
repo_path = snapshot_download(repo_id=args.resume_path)
|
| 108 |
+
|
| 109 |
+
# Auto-detect device (CUDA if available, otherwise CPU)
|
| 110 |
+
# Note: MPS is disabled due to memory and compatibility issues
|
| 111 |
+
if torch.cuda.is_available():
|
| 112 |
+
device = 'cuda'
|
| 113 |
+
else:
|
| 114 |
+
device = 'cpu'
|
| 115 |
+
print("Note: Running on CPU. This will be slower but more stable.")
|
| 116 |
+
print(f"Using device: {device}")
|
| 117 |
+
|
| 118 |
+
# Pipeline
|
| 119 |
+
pipeline = CatVTONPipeline(
|
| 120 |
+
base_ckpt=args.base_model_path,
|
| 121 |
+
attn_ckpt=repo_path,
|
| 122 |
+
attn_ckpt_version="mix",
|
| 123 |
+
weight_dtype=init_weight_dtype(args.mixed_precision),
|
| 124 |
+
use_tf32=args.allow_tf32 and torch.cuda.is_available(), # Only use TF32 if CUDA is available
|
| 125 |
+
device=device
|
| 126 |
+
)
|
| 127 |
+
# AutoMasker
|
| 128 |
+
mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
|
| 129 |
+
automasker = AutoMasker(
|
| 130 |
+
densepose_ckpt=os.path.join(repo_path, "DensePose"),
|
| 131 |
+
schp_ckpt=os.path.join(repo_path, "SCHP"),
|
| 132 |
+
device=device,
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
def submit_function(
|
| 136 |
+
person_image,
|
| 137 |
+
cloth_image,
|
| 138 |
+
cloth_type,
|
| 139 |
+
num_inference_steps,
|
| 140 |
+
guidance_scale,
|
| 141 |
+
seed,
|
| 142 |
+
show_type
|
| 143 |
+
):
|
| 144 |
+
person_image, mask = person_image["background"], person_image["layers"][0]
|
| 145 |
+
mask = Image.open(mask).convert("L")
|
| 146 |
+
if len(np.unique(np.array(mask))) == 1:
|
| 147 |
+
mask = None
|
| 148 |
+
else:
|
| 149 |
+
mask = np.array(mask)
|
| 150 |
+
mask[mask > 0] = 255
|
| 151 |
+
mask = Image.fromarray(mask)
|
| 152 |
+
|
| 153 |
+
tmp_folder = args.output_dir
|
| 154 |
+
date_str = datetime.now().strftime("%Y%m%d%H%M%S")
|
| 155 |
+
result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
|
| 156 |
+
if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
|
| 157 |
+
os.makedirs(os.path.join(tmp_folder, date_str[:8]))
|
| 158 |
+
|
| 159 |
+
generator = None
|
| 160 |
+
if seed != -1:
|
| 161 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
| 162 |
+
|
| 163 |
+
person_image = Image.open(person_image).convert("RGB")
|
| 164 |
+
cloth_image = Image.open(cloth_image).convert("RGB")
|
| 165 |
+
|
| 166 |
+
# Use default resolution
|
| 167 |
+
target_width = args.width
|
| 168 |
+
target_height = args.height
|
| 169 |
+
|
| 170 |
+
person_image = resize_and_crop(person_image, (target_width, target_height))
|
| 171 |
+
cloth_image = resize_and_padding(cloth_image, (target_width, target_height))
|
| 172 |
+
|
| 173 |
+
# Process mask
|
| 174 |
+
if mask is not None:
|
| 175 |
+
mask = resize_and_crop(mask, (target_width, target_height))
|
| 176 |
+
else:
|
| 177 |
+
mask = automasker(
|
| 178 |
+
person_image,
|
| 179 |
+
cloth_type
|
| 180 |
+
)['mask']
|
| 181 |
+
mask = mask_processor.blur(mask, blur_factor=9)
|
| 182 |
+
|
| 183 |
+
# Inference
|
| 184 |
+
# try:
|
| 185 |
+
result_image = pipeline(
|
| 186 |
+
image=person_image,
|
| 187 |
+
condition_image=cloth_image,
|
| 188 |
+
mask=mask,
|
| 189 |
+
num_inference_steps=num_inference_steps,
|
| 190 |
+
guidance_scale=guidance_scale,
|
| 191 |
+
generator=generator
|
| 192 |
+
)[0]
|
| 193 |
+
# except Exception as e:
|
| 194 |
+
# raise gr.Error(
|
| 195 |
+
# "An error occurred. Please try again later: {}".format(e)
|
| 196 |
+
# )
|
| 197 |
+
|
| 198 |
+
# Post-process
|
| 199 |
+
masked_person = vis_mask(person_image, mask)
|
| 200 |
+
save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
|
| 201 |
+
save_result_image.save(result_save_path)
|
| 202 |
+
if show_type == "result only":
|
| 203 |
+
return result_image
|
| 204 |
+
else:
|
| 205 |
+
width, height = person_image.size
|
| 206 |
+
if show_type == "input & result":
|
| 207 |
+
condition_width = width // 2
|
| 208 |
+
conditions = image_grid([person_image, cloth_image], 2, 1)
|
| 209 |
+
else:
|
| 210 |
+
condition_width = width // 3
|
| 211 |
+
conditions = image_grid([person_image, masked_person , cloth_image], 3, 1)
|
| 212 |
+
conditions = conditions.resize((condition_width, height), Image.NEAREST)
|
| 213 |
+
new_result_image = Image.new("RGB", (width + condition_width + 5, height))
|
| 214 |
+
new_result_image.paste(conditions, (0, 0))
|
| 215 |
+
new_result_image.paste(result_image, (condition_width + 5, 0))
|
| 216 |
+
return new_result_image
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def person_example_fn(image_path):
|
| 220 |
+
return image_path
|
| 221 |
+
|
| 222 |
+
HEADER = """
|
| 223 |
+
<h1 style="text-align: center;"> 🐈 CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models </h1>
|
| 224 |
+
<div style="display: flex; justify-content: center; align-items: center;">
|
| 225 |
+
<a href="http://arxiv.org/abs/2407.15886" style="margin: 0 2px;">
|
| 226 |
+
<img src='https://img.shields.io/badge/arXiv-2407.15886-red?style=flat&logo=arXiv&logoColor=red' alt='arxiv'>
|
| 227 |
+
</a>
|
| 228 |
+
<a href='https://huggingface.co/zhengchong/CatVTON' style="margin: 0 2px;">
|
| 229 |
+
<img src='https://img.shields.io/badge/Hugging Face-ckpts-orange?style=flat&logo=HuggingFace&logoColor=orange' alt='huggingface'>
|
| 230 |
+
</a>
|
| 231 |
+
<a href="https://github.com/Zheng-Chong/CatVTON" style="margin: 0 2px;">
|
| 232 |
+
<img src='https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub' alt='GitHub'>
|
| 233 |
+
</a>
|
| 234 |
+
<a href="http://120.76.142.206:8888" style="margin: 0 2px;">
|
| 235 |
+
<img src='https://img.shields.io/badge/Demo-Gradio-gold?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
|
| 236 |
+
</a>
|
| 237 |
+
<a href="https://huggingface.co/spaces/zhengchong/CatVTON" style="margin: 0 2px;">
|
| 238 |
+
<img src='https://img.shields.io/badge/Space-ZeroGPU-orange?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
|
| 239 |
+
</a>
|
| 240 |
+
<a href='https://zheng-chong.github.io/CatVTON/' style="margin: 0 2px;">
|
| 241 |
+
<img src='https://img.shields.io/badge/Webpage-Project-silver?style=flat&logo=&logoColor=orange' alt='webpage'>
|
| 242 |
+
</a>
|
| 243 |
+
<a href="https://github.com/Zheng-Chong/CatVTON/LICENCE" style="margin: 0 2px;">
|
| 244 |
+
<img src='https://img.shields.io/badge/License-CC BY--NC--SA--4.0-lightgreen?style=flat&logo=Lisence' alt='License'>
|
| 245 |
+
</a>
|
| 246 |
+
</div>
|
| 247 |
+
<br>
|
| 248 |
+
· This demo and our weights are only for <span>Non-commercial Use</span>. <br>
|
| 249 |
+
· You can try CatVTON in our <a href="https://huggingface.co/spaces/zhengchong/CatVTON">HuggingFace Space</a> or our <a href="http://120.76.142.206:8888">online demo</a> (run on 3090). <br>
|
| 250 |
+
· Thanks to <a href="https://huggingface.co/zero-gpu-explorers">ZeroGPU</a> for providing A100 for our <a href="https://huggingface.co/spaces/zhengchong/CatVTON">HuggingFace Space</a>. <br>
|
| 251 |
+
· SafetyChecker is set to filter NSFW content, but it may block normal results too. Please adjust the <span>`seed`</span> for normal outcomes.<br>
|
| 252 |
+
"""
|
| 253 |
+
|
| 254 |
+
def app_gradio():
|
| 255 |
+
with gr.Blocks(title="CatVTON") as demo:
|
| 256 |
+
gr.Markdown(HEADER)
|
| 257 |
+
with gr.Row():
|
| 258 |
+
with gr.Column(scale=1, min_width=350):
|
| 259 |
+
with gr.Row():
|
| 260 |
+
image_path = gr.Image(
|
| 261 |
+
type="filepath",
|
| 262 |
+
interactive=True,
|
| 263 |
+
visible=False,
|
| 264 |
+
)
|
| 265 |
+
person_image = gr.ImageEditor(
|
| 266 |
+
interactive=True, label="Person Image", type="filepath"
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
with gr.Row():
|
| 270 |
+
with gr.Column(scale=1, min_width=230):
|
| 271 |
+
cloth_image = gr.Image(
|
| 272 |
+
interactive=True, label="Condition Image", type="filepath"
|
| 273 |
+
)
|
| 274 |
+
with gr.Column(scale=1, min_width=120):
|
| 275 |
+
gr.Markdown(
|
| 276 |
+
'<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
|
| 277 |
+
)
|
| 278 |
+
cloth_type = gr.Radio(
|
| 279 |
+
label="Try-On Cloth Type",
|
| 280 |
+
choices=["upper", "lower", "overall"],
|
| 281 |
+
value="upper",
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
submit = gr.Button("Submit")
|
| 286 |
+
gr.Markdown(
|
| 287 |
+
'<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
gr.Markdown(
|
| 291 |
+
'<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
|
| 292 |
+
)
|
| 293 |
+
with gr.Accordion("Advanced Options", open=False):
|
| 294 |
+
num_inference_steps = gr.Slider(
|
| 295 |
+
label="Inference Step", minimum=10, maximum=100, step=5, value=50
|
| 296 |
+
)
|
| 297 |
+
# Guidence Scale
|
| 298 |
+
guidance_scale = gr.Slider(
|
| 299 |
+
label="CFG Strenth", minimum=0.0, maximum=7.5, step=0.5, value=2.5
|
| 300 |
+
)
|
| 301 |
+
# Random Seed
|
| 302 |
+
seed = gr.Slider(
|
| 303 |
+
label="Seed", minimum=-1, maximum=10000, step=1, value=42
|
| 304 |
+
)
|
| 305 |
+
show_type = gr.Radio(
|
| 306 |
+
label="Show Type",
|
| 307 |
+
choices=["result only", "input & result", "input & mask & result"],
|
| 308 |
+
value="input & mask & result",
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
with gr.Column(scale=2, min_width=500):
|
| 312 |
+
result_image = gr.Image(interactive=False, label="Result")
|
| 313 |
+
with gr.Row():
|
| 314 |
+
# Photo Examples
|
| 315 |
+
root_path = "resource/demo/example"
|
| 316 |
+
with gr.Column():
|
| 317 |
+
men_exm = gr.Examples(
|
| 318 |
+
examples=[
|
| 319 |
+
os.path.join(root_path, "person", "men", _)
|
| 320 |
+
for _ in os.listdir(os.path.join(root_path, "person", "men"))
|
| 321 |
+
],
|
| 322 |
+
examples_per_page=4,
|
| 323 |
+
inputs=image_path,
|
| 324 |
+
label="Person Examples ①",
|
| 325 |
+
)
|
| 326 |
+
women_exm = gr.Examples(
|
| 327 |
+
examples=[
|
| 328 |
+
os.path.join(root_path, "person", "women", _)
|
| 329 |
+
for _ in os.listdir(os.path.join(root_path, "person", "women"))
|
| 330 |
+
],
|
| 331 |
+
examples_per_page=4,
|
| 332 |
+
inputs=image_path,
|
| 333 |
+
label="Person Examples ②",
|
| 334 |
+
)
|
| 335 |
+
gr.Markdown(
|
| 336 |
+
'<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
|
| 337 |
+
)
|
| 338 |
+
with gr.Column():
|
| 339 |
+
condition_upper_exm = gr.Examples(
|
| 340 |
+
examples=[
|
| 341 |
+
os.path.join(root_path, "condition", "upper", _)
|
| 342 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
|
| 343 |
+
],
|
| 344 |
+
examples_per_page=4,
|
| 345 |
+
inputs=cloth_image,
|
| 346 |
+
label="Condition Upper Examples",
|
| 347 |
+
)
|
| 348 |
+
condition_overall_exm = gr.Examples(
|
| 349 |
+
examples=[
|
| 350 |
+
os.path.join(root_path, "condition", "overall", _)
|
| 351 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
|
| 352 |
+
],
|
| 353 |
+
examples_per_page=4,
|
| 354 |
+
inputs=cloth_image,
|
| 355 |
+
label="Condition Overall Examples",
|
| 356 |
+
)
|
| 357 |
+
condition_person_exm = gr.Examples(
|
| 358 |
+
examples=[
|
| 359 |
+
os.path.join(root_path, "condition", "person", _)
|
| 360 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "person"))
|
| 361 |
+
],
|
| 362 |
+
examples_per_page=4,
|
| 363 |
+
inputs=cloth_image,
|
| 364 |
+
label="Condition Reference Person Examples",
|
| 365 |
+
)
|
| 366 |
+
gr.Markdown(
|
| 367 |
+
'<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
image_path.change(
|
| 371 |
+
person_example_fn, inputs=image_path, outputs=person_image
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
submit.click(
|
| 375 |
+
submit_function,
|
| 376 |
+
[
|
| 377 |
+
person_image,
|
| 378 |
+
cloth_image,
|
| 379 |
+
cloth_type,
|
| 380 |
+
num_inference_steps,
|
| 381 |
+
guidance_scale,
|
| 382 |
+
seed,
|
| 383 |
+
show_type,
|
| 384 |
+
],
|
| 385 |
+
result_image,
|
| 386 |
+
)
|
| 387 |
+
demo.queue().launch(share=True, show_error=True)
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
if __name__ == "__main__":
|
| 391 |
+
app_gradio()
|
app_flux.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import gradio as gr
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 9 |
+
from huggingface_hub import snapshot_download
|
| 10 |
+
from PIL import Image
|
| 11 |
+
|
| 12 |
+
from model.cloth_masker import AutoMasker, vis_mask
|
| 13 |
+
from model.flux.pipeline_flux_tryon import FluxTryOnPipeline
|
| 14 |
+
from utils import resize_and_crop, resize_and_padding
|
| 15 |
+
|
| 16 |
+
def parse_args():
|
| 17 |
+
parser = argparse.ArgumentParser(description="FLUX Try-On Demo")
|
| 18 |
+
parser.add_argument(
|
| 19 |
+
"--base_model_path",
|
| 20 |
+
type=str,
|
| 21 |
+
# default="black-forest-labs/FLUX.1-Fill-dev",
|
| 22 |
+
default="Models/FLUX.1-Fill-dev",
|
| 23 |
+
help="The path to the base model to use for evaluation."
|
| 24 |
+
)
|
| 25 |
+
parser.add_argument(
|
| 26 |
+
"--resume_path",
|
| 27 |
+
type=str,
|
| 28 |
+
default="zhengchong/CatVTON",
|
| 29 |
+
help="The Path to the checkpoint of trained tryon model."
|
| 30 |
+
)
|
| 31 |
+
parser.add_argument(
|
| 32 |
+
"--output_dir",
|
| 33 |
+
type=str,
|
| 34 |
+
default="resource/demo/output",
|
| 35 |
+
help="The output directory where the model predictions will be written."
|
| 36 |
+
)
|
| 37 |
+
parser.add_argument(
|
| 38 |
+
"--mixed_precision",
|
| 39 |
+
type=str,
|
| 40 |
+
default="bf16",
|
| 41 |
+
choices=["no", "fp16", "bf16"],
|
| 42 |
+
help="Whether to use mixed precision."
|
| 43 |
+
)
|
| 44 |
+
parser.add_argument(
|
| 45 |
+
"--allow_tf32",
|
| 46 |
+
action="store_true",
|
| 47 |
+
default=True,
|
| 48 |
+
help="Whether or not to allow TF32 on Ampere GPUs."
|
| 49 |
+
)
|
| 50 |
+
parser.add_argument(
|
| 51 |
+
"--width",
|
| 52 |
+
type=int,
|
| 53 |
+
default=768,
|
| 54 |
+
help="The width of the input image."
|
| 55 |
+
)
|
| 56 |
+
parser.add_argument(
|
| 57 |
+
"--height",
|
| 58 |
+
type=int,
|
| 59 |
+
default=1024,
|
| 60 |
+
help="The height of the input image."
|
| 61 |
+
)
|
| 62 |
+
return parser.parse_args()
|
| 63 |
+
|
| 64 |
+
def image_grid(imgs, rows, cols):
|
| 65 |
+
assert len(imgs) == rows * cols
|
| 66 |
+
w, h = imgs[0].size
|
| 67 |
+
grid = Image.new("RGB", size=(cols * w, rows * h))
|
| 68 |
+
for i, img in enumerate(imgs):
|
| 69 |
+
grid.paste(img, box=(i % cols * w, i // cols * h))
|
| 70 |
+
return grid
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def submit_function_flux(
|
| 74 |
+
person_image,
|
| 75 |
+
cloth_image,
|
| 76 |
+
cloth_type,
|
| 77 |
+
num_inference_steps,
|
| 78 |
+
guidance_scale,
|
| 79 |
+
seed,
|
| 80 |
+
show_type
|
| 81 |
+
):
|
| 82 |
+
|
| 83 |
+
# Process image editor input
|
| 84 |
+
person_image, mask = person_image["background"], person_image["layers"][0]
|
| 85 |
+
mask = Image.open(mask).convert("L")
|
| 86 |
+
if len(np.unique(np.array(mask))) == 1:
|
| 87 |
+
mask = None
|
| 88 |
+
else:
|
| 89 |
+
mask = np.array(mask)
|
| 90 |
+
mask[mask > 0] = 255
|
| 91 |
+
mask = Image.fromarray(mask)
|
| 92 |
+
|
| 93 |
+
# Set random seed
|
| 94 |
+
generator = None
|
| 95 |
+
if seed != -1:
|
| 96 |
+
generator = torch.Generator(device='cuda').manual_seed(seed)
|
| 97 |
+
|
| 98 |
+
# Process input images
|
| 99 |
+
person_image = Image.open(person_image).convert("RGB")
|
| 100 |
+
cloth_image = Image.open(cloth_image).convert("RGB")
|
| 101 |
+
|
| 102 |
+
# Adjust image sizes
|
| 103 |
+
person_image = resize_and_crop(person_image, (args.width, args.height))
|
| 104 |
+
cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
|
| 105 |
+
|
| 106 |
+
# Process mask
|
| 107 |
+
if mask is not None:
|
| 108 |
+
mask = resize_and_crop(mask, (args.width, args.height))
|
| 109 |
+
else:
|
| 110 |
+
mask = automasker(
|
| 111 |
+
person_image,
|
| 112 |
+
cloth_type
|
| 113 |
+
)['mask']
|
| 114 |
+
mask = mask_processor.blur(mask, blur_factor=9)
|
| 115 |
+
|
| 116 |
+
# Inference
|
| 117 |
+
result_image = pipeline_flux(
|
| 118 |
+
image=person_image,
|
| 119 |
+
condition_image=cloth_image,
|
| 120 |
+
mask_image=mask,
|
| 121 |
+
height=args.height,
|
| 122 |
+
width=args.width,
|
| 123 |
+
num_inference_steps=num_inference_steps,
|
| 124 |
+
guidance_scale=guidance_scale,
|
| 125 |
+
generator=generator
|
| 126 |
+
).images[0]
|
| 127 |
+
|
| 128 |
+
# Post-processing
|
| 129 |
+
masked_person = vis_mask(person_image, mask)
|
| 130 |
+
|
| 131 |
+
# Return result based on show type
|
| 132 |
+
if show_type == "result only":
|
| 133 |
+
return result_image
|
| 134 |
+
else:
|
| 135 |
+
width, height = person_image.size
|
| 136 |
+
if show_type == "input & result":
|
| 137 |
+
condition_width = width // 2
|
| 138 |
+
conditions = image_grid([person_image, cloth_image], 2, 1)
|
| 139 |
+
else:
|
| 140 |
+
condition_width = width // 3
|
| 141 |
+
conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)
|
| 142 |
+
|
| 143 |
+
conditions = conditions.resize((condition_width, height), Image.NEAREST)
|
| 144 |
+
new_result_image = Image.new("RGB", (width + condition_width + 5, height))
|
| 145 |
+
new_result_image.paste(conditions, (0, 0))
|
| 146 |
+
new_result_image.paste(result_image, (condition_width + 5, 0))
|
| 147 |
+
return new_result_image
|
| 148 |
+
|
| 149 |
+
def person_example_fn(image_path):
|
| 150 |
+
return image_path
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def app_gradio():
|
| 154 |
+
with gr.Blocks(title="CatVTON with FLUX.1-Fill-dev") as demo:
|
| 155 |
+
gr.Markdown("# CatVTON with FLUX.1-Fill-dev")
|
| 156 |
+
with gr.Row():
|
| 157 |
+
with gr.Column(scale=1, min_width=350):
|
| 158 |
+
with gr.Row():
|
| 159 |
+
image_path_flux = gr.Image(
|
| 160 |
+
type="filepath",
|
| 161 |
+
interactive=True,
|
| 162 |
+
visible=False,
|
| 163 |
+
)
|
| 164 |
+
person_image_flux = gr.ImageEditor(
|
| 165 |
+
interactive=True, label="Person Image", type="filepath"
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
with gr.Row():
|
| 169 |
+
with gr.Column(scale=1, min_width=230):
|
| 170 |
+
cloth_image_flux = gr.Image(
|
| 171 |
+
interactive=True, label="Condition Image", type="filepath"
|
| 172 |
+
)
|
| 173 |
+
with gr.Column(scale=1, min_width=120):
|
| 174 |
+
gr.Markdown(
|
| 175 |
+
'<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
|
| 176 |
+
)
|
| 177 |
+
cloth_type = gr.Radio(
|
| 178 |
+
label="Try-On Cloth Type",
|
| 179 |
+
choices=["upper", "lower", "overall"],
|
| 180 |
+
value="upper",
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
submit_flux = gr.Button("Submit")
|
| 184 |
+
gr.Markdown(
|
| 185 |
+
'<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
with gr.Accordion("Advanced Options", open=False):
|
| 189 |
+
num_inference_steps_flux = gr.Slider(
|
| 190 |
+
label="Inference Step", minimum=10, maximum=100, step=5, value=50
|
| 191 |
+
)
|
| 192 |
+
# Guidence Scale
|
| 193 |
+
guidance_scale_flux = gr.Slider(
|
| 194 |
+
label="CFG Strenth", minimum=0.0, maximum=50, step=0.5, value=30
|
| 195 |
+
)
|
| 196 |
+
# Random Seed
|
| 197 |
+
seed_flux = gr.Slider(
|
| 198 |
+
label="Seed", minimum=-1, maximum=10000, step=1, value=42
|
| 199 |
+
)
|
| 200 |
+
show_type = gr.Radio(
|
| 201 |
+
label="Show Type",
|
| 202 |
+
choices=["result only", "input & result", "input & mask & result"],
|
| 203 |
+
value="input & mask & result",
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
with gr.Column(scale=2, min_width=500):
|
| 207 |
+
result_image_flux = gr.Image(interactive=False, label="Result")
|
| 208 |
+
with gr.Row():
|
| 209 |
+
# Photo Examples
|
| 210 |
+
root_path = "resource/demo/example"
|
| 211 |
+
with gr.Column():
|
| 212 |
+
gr.Examples(
|
| 213 |
+
examples=[
|
| 214 |
+
os.path.join(root_path, "person", "men", _)
|
| 215 |
+
for _ in os.listdir(os.path.join(root_path, "person", "men"))
|
| 216 |
+
],
|
| 217 |
+
examples_per_page=4,
|
| 218 |
+
inputs=image_path_flux,
|
| 219 |
+
label="Person Examples ①",
|
| 220 |
+
)
|
| 221 |
+
gr.Examples(
|
| 222 |
+
examples=[
|
| 223 |
+
os.path.join(root_path, "person", "women", _)
|
| 224 |
+
for _ in os.listdir(os.path.join(root_path, "person", "women"))
|
| 225 |
+
],
|
| 226 |
+
examples_per_page=4,
|
| 227 |
+
inputs=image_path_flux,
|
| 228 |
+
label="Person Examples ②",
|
| 229 |
+
)
|
| 230 |
+
gr.Markdown(
|
| 231 |
+
'<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
|
| 232 |
+
)
|
| 233 |
+
with gr.Column():
|
| 234 |
+
gr.Examples(
|
| 235 |
+
examples=[
|
| 236 |
+
os.path.join(root_path, "condition", "upper", _)
|
| 237 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
|
| 238 |
+
],
|
| 239 |
+
examples_per_page=4,
|
| 240 |
+
inputs=cloth_image_flux,
|
| 241 |
+
label="Condition Upper Examples",
|
| 242 |
+
)
|
| 243 |
+
gr.Examples(
|
| 244 |
+
examples=[
|
| 245 |
+
os.path.join(root_path, "condition", "overall", _)
|
| 246 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
|
| 247 |
+
],
|
| 248 |
+
examples_per_page=4,
|
| 249 |
+
inputs=cloth_image_flux,
|
| 250 |
+
label="Condition Overall Examples",
|
| 251 |
+
)
|
| 252 |
+
condition_person_exm = gr.Examples(
|
| 253 |
+
examples=[
|
| 254 |
+
os.path.join(root_path, "condition", "person", _)
|
| 255 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "person"))
|
| 256 |
+
],
|
| 257 |
+
examples_per_page=4,
|
| 258 |
+
inputs=cloth_image_flux,
|
| 259 |
+
label="Condition Reference Person Examples",
|
| 260 |
+
)
|
| 261 |
+
gr.Markdown(
|
| 262 |
+
'<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
image_path_flux.change(
|
| 267 |
+
person_example_fn, inputs=image_path_flux, outputs=person_image_flux
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
submit_flux.click(
|
| 271 |
+
submit_function_flux,
|
| 272 |
+
[person_image_flux, cloth_image_flux, cloth_type, num_inference_steps_flux, guidance_scale_flux, seed_flux, show_type],
|
| 273 |
+
result_image_flux,
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
demo.queue().launch(share=True, show_error=True)
|
| 278 |
+
|
| 279 |
+
# 解析参数
|
| 280 |
+
args = parse_args()
|
| 281 |
+
|
| 282 |
+
# 加载模型
|
| 283 |
+
repo_path = snapshot_download(repo_id=args.resume_path)
|
| 284 |
+
pipeline_flux = FluxTryOnPipeline.from_pretrained(args.base_model_path)
|
| 285 |
+
pipeline_flux.load_lora_weights(
|
| 286 |
+
os.path.join(repo_path, "flux-lora"),
|
| 287 |
+
weight_name='pytorch_lora_weights.safetensors'
|
| 288 |
+
)
|
| 289 |
+
pipeline_flux.to("cuda", torch.bfloat16)
|
| 290 |
+
|
| 291 |
+
# 初始化 AutoMasker
|
| 292 |
+
mask_processor = VaeImageProcessor(
|
| 293 |
+
vae_scale_factor=8,
|
| 294 |
+
do_normalize=False,
|
| 295 |
+
do_binarize=True,
|
| 296 |
+
do_convert_grayscale=True
|
| 297 |
+
)
|
| 298 |
+
automasker = AutoMasker(
|
| 299 |
+
densepose_ckpt=os.path.join(repo_path, "DensePose"),
|
| 300 |
+
schp_ckpt=os.path.join(repo_path, "SCHP"),
|
| 301 |
+
device='cuda'
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
if __name__ == "__main__":
|
| 305 |
+
app_gradio()
|
app_p2p.py
ADDED
|
@@ -0,0 +1,567 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
|
| 5 |
+
import gradio as gr
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 9 |
+
from huggingface_hub import snapshot_download
|
| 10 |
+
from PIL import Image
|
| 11 |
+
|
| 12 |
+
from model.cloth_masker import AutoMasker, vis_mask
|
| 13 |
+
from model.pipeline import CatVTONPipeline, CatVTONPix2PixPipeline
|
| 14 |
+
from utils import init_weight_dtype, resize_and_crop, resize_and_padding
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def parse_args():
|
| 18 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
| 19 |
+
parser.add_argument(
|
| 20 |
+
"--p2p_base_model_path",
|
| 21 |
+
type=str,
|
| 22 |
+
default="timbrooks/instruct-pix2pix",
|
| 23 |
+
help=(
|
| 24 |
+
"The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
|
| 25 |
+
),
|
| 26 |
+
)
|
| 27 |
+
parser.add_argument(
|
| 28 |
+
"--ip_base_model_path",
|
| 29 |
+
type=str,
|
| 30 |
+
default="booksforcharlie/stable-diffusion-inpainting",
|
| 31 |
+
help=(
|
| 32 |
+
"The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
|
| 33 |
+
),
|
| 34 |
+
)
|
| 35 |
+
parser.add_argument(
|
| 36 |
+
"--p2p_resume_path",
|
| 37 |
+
type=str,
|
| 38 |
+
default="zhengchong/CatVTON-MaskFree",
|
| 39 |
+
help=(
|
| 40 |
+
"The Path to the checkpoint of trained tryon model."
|
| 41 |
+
),
|
| 42 |
+
)
|
| 43 |
+
parser.add_argument(
|
| 44 |
+
"--ip_resume_path",
|
| 45 |
+
type=str,
|
| 46 |
+
default="zhengchong/CatVTON",
|
| 47 |
+
help=(
|
| 48 |
+
"The Path to the checkpoint of trained tryon model."
|
| 49 |
+
),
|
| 50 |
+
)
|
| 51 |
+
parser.add_argument(
|
| 52 |
+
"--output_dir",
|
| 53 |
+
type=str,
|
| 54 |
+
default="resource/demo/output",
|
| 55 |
+
help="The output directory where the model predictions will be written.",
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
parser.add_argument(
|
| 59 |
+
"--width",
|
| 60 |
+
type=int,
|
| 61 |
+
default=768,
|
| 62 |
+
help=(
|
| 63 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
| 64 |
+
" resolution"
|
| 65 |
+
),
|
| 66 |
+
)
|
| 67 |
+
parser.add_argument(
|
| 68 |
+
"--height",
|
| 69 |
+
type=int,
|
| 70 |
+
default=1024,
|
| 71 |
+
help=(
|
| 72 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
| 73 |
+
" resolution"
|
| 74 |
+
),
|
| 75 |
+
)
|
| 76 |
+
parser.add_argument(
|
| 77 |
+
"--repaint",
|
| 78 |
+
action="store_true",
|
| 79 |
+
help="Whether to repaint the result image with the original background."
|
| 80 |
+
)
|
| 81 |
+
parser.add_argument(
|
| 82 |
+
"--allow_tf32",
|
| 83 |
+
action="store_true",
|
| 84 |
+
default=True,
|
| 85 |
+
help=(
|
| 86 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
| 87 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
| 88 |
+
),
|
| 89 |
+
)
|
| 90 |
+
parser.add_argument(
|
| 91 |
+
"--mixed_precision",
|
| 92 |
+
type=str,
|
| 93 |
+
default="bf16",
|
| 94 |
+
choices=["no", "fp16", "bf16"],
|
| 95 |
+
help=(
|
| 96 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
| 97 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
| 98 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
| 99 |
+
),
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
args = parser.parse_args()
|
| 103 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 104 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
| 105 |
+
args.local_rank = env_local_rank
|
| 106 |
+
|
| 107 |
+
return args
|
| 108 |
+
|
| 109 |
+
def image_grid(imgs, rows, cols):
|
| 110 |
+
assert len(imgs) == rows * cols
|
| 111 |
+
|
| 112 |
+
w, h = imgs[0].size
|
| 113 |
+
grid = Image.new("RGB", size=(cols * w, rows * h))
|
| 114 |
+
|
| 115 |
+
for i, img in enumerate(imgs):
|
| 116 |
+
grid.paste(img, box=(i % cols * w, i // cols * h))
|
| 117 |
+
return grid
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
args = parse_args()
|
| 121 |
+
repo_path = snapshot_download(repo_id=args.ip_resume_path)
|
| 122 |
+
# Pipeline
|
| 123 |
+
pipeline_p2p = CatVTONPix2PixPipeline(
|
| 124 |
+
base_ckpt=args.p2p_base_model_path,
|
| 125 |
+
attn_ckpt=repo_path,
|
| 126 |
+
attn_ckpt_version="mix-48k-1024",
|
| 127 |
+
weight_dtype=init_weight_dtype(args.mixed_precision),
|
| 128 |
+
use_tf32=args.allow_tf32,
|
| 129 |
+
device='cuda'
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
# Pipeline
|
| 133 |
+
repo_path = snapshot_download(repo_id=args.ip_resume_path)
|
| 134 |
+
pipeline = CatVTONPipeline(
|
| 135 |
+
base_ckpt=args.ip_base_model_path,
|
| 136 |
+
attn_ckpt=repo_path,
|
| 137 |
+
attn_ckpt_version="mix",
|
| 138 |
+
weight_dtype=init_weight_dtype(args.mixed_precision),
|
| 139 |
+
use_tf32=args.allow_tf32,
|
| 140 |
+
device='cuda'
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
# AutoMasker
|
| 144 |
+
mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
|
| 145 |
+
automasker = AutoMasker(
|
| 146 |
+
densepose_ckpt=os.path.join(repo_path, "DensePose"),
|
| 147 |
+
schp_ckpt=os.path.join(repo_path, "SCHP"),
|
| 148 |
+
device='cuda',
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def submit_function_p2p(
|
| 153 |
+
person_image,
|
| 154 |
+
cloth_image,
|
| 155 |
+
num_inference_steps,
|
| 156 |
+
guidance_scale,
|
| 157 |
+
seed):
|
| 158 |
+
person_image= person_image["background"]
|
| 159 |
+
|
| 160 |
+
tmp_folder = args.output_dir
|
| 161 |
+
date_str = datetime.now().strftime("%Y%m%d%H%M%S")
|
| 162 |
+
result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
|
| 163 |
+
if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
|
| 164 |
+
os.makedirs(os.path.join(tmp_folder, date_str[:8]))
|
| 165 |
+
|
| 166 |
+
generator = None
|
| 167 |
+
if seed != -1:
|
| 168 |
+
generator = torch.Generator(device='cuda').manual_seed(seed)
|
| 169 |
+
|
| 170 |
+
person_image = Image.open(person_image).convert("RGB")
|
| 171 |
+
cloth_image = Image.open(cloth_image).convert("RGB")
|
| 172 |
+
person_image = resize_and_crop(person_image, (args.width, args.height))
|
| 173 |
+
cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
|
| 174 |
+
|
| 175 |
+
# Inference
|
| 176 |
+
try:
|
| 177 |
+
result_image = pipeline_p2p(
|
| 178 |
+
image=person_image,
|
| 179 |
+
condition_image=cloth_image,
|
| 180 |
+
num_inference_steps=num_inference_steps,
|
| 181 |
+
guidance_scale=guidance_scale,
|
| 182 |
+
generator=generator
|
| 183 |
+
)[0]
|
| 184 |
+
except Exception as e:
|
| 185 |
+
raise gr.Error(
|
| 186 |
+
"An error occurred. Please try again later: {}".format(e)
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
# Post-process
|
| 190 |
+
save_result_image = image_grid([person_image, cloth_image, result_image], 1, 3)
|
| 191 |
+
save_result_image.save(result_save_path)
|
| 192 |
+
return result_image
|
| 193 |
+
|
| 194 |
+
def submit_function(
|
| 195 |
+
person_image,
|
| 196 |
+
cloth_image,
|
| 197 |
+
cloth_type,
|
| 198 |
+
num_inference_steps,
|
| 199 |
+
guidance_scale,
|
| 200 |
+
seed,
|
| 201 |
+
show_type
|
| 202 |
+
):
|
| 203 |
+
person_image, mask = person_image["background"], person_image["layers"][0]
|
| 204 |
+
mask = Image.open(mask).convert("L")
|
| 205 |
+
if len(np.unique(np.array(mask))) == 1:
|
| 206 |
+
mask = None
|
| 207 |
+
else:
|
| 208 |
+
mask = np.array(mask)
|
| 209 |
+
mask[mask > 0] = 255
|
| 210 |
+
mask = Image.fromarray(mask)
|
| 211 |
+
|
| 212 |
+
tmp_folder = args.output_dir
|
| 213 |
+
date_str = datetime.now().strftime("%Y%m%d%H%M%S")
|
| 214 |
+
result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
|
| 215 |
+
if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
|
| 216 |
+
os.makedirs(os.path.join(tmp_folder, date_str[:8]))
|
| 217 |
+
|
| 218 |
+
generator = None
|
| 219 |
+
if seed != -1:
|
| 220 |
+
generator = torch.Generator(device='cuda').manual_seed(seed)
|
| 221 |
+
|
| 222 |
+
person_image = Image.open(person_image).convert("RGB")
|
| 223 |
+
cloth_image = Image.open(cloth_image).convert("RGB")
|
| 224 |
+
person_image = resize_and_crop(person_image, (args.width, args.height))
|
| 225 |
+
cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
|
| 226 |
+
|
| 227 |
+
# Process mask
|
| 228 |
+
if mask is not None:
|
| 229 |
+
mask = resize_and_crop(mask, (args.width, args.height))
|
| 230 |
+
else:
|
| 231 |
+
mask = automasker(
|
| 232 |
+
person_image,
|
| 233 |
+
cloth_type
|
| 234 |
+
)['mask']
|
| 235 |
+
mask = mask_processor.blur(mask, blur_factor=9)
|
| 236 |
+
|
| 237 |
+
# Inference
|
| 238 |
+
# try:
|
| 239 |
+
result_image = pipeline(
|
| 240 |
+
image=person_image,
|
| 241 |
+
condition_image=cloth_image,
|
| 242 |
+
mask=mask,
|
| 243 |
+
num_inference_steps=num_inference_steps,
|
| 244 |
+
guidance_scale=guidance_scale,
|
| 245 |
+
generator=generator
|
| 246 |
+
)[0]
|
| 247 |
+
# except Exception as e:
|
| 248 |
+
# raise gr.Error(
|
| 249 |
+
# "An error occurred. Please try again later: {}".format(e)
|
| 250 |
+
# )
|
| 251 |
+
|
| 252 |
+
# Post-process
|
| 253 |
+
masked_person = vis_mask(person_image, mask)
|
| 254 |
+
save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
|
| 255 |
+
save_result_image.save(result_save_path)
|
| 256 |
+
if show_type == "result only":
|
| 257 |
+
return result_image
|
| 258 |
+
else:
|
| 259 |
+
width, height = person_image.size
|
| 260 |
+
if show_type == "input & result":
|
| 261 |
+
condition_width = width // 2
|
| 262 |
+
conditions = image_grid([person_image, cloth_image], 2, 1)
|
| 263 |
+
else:
|
| 264 |
+
condition_width = width // 3
|
| 265 |
+
conditions = image_grid([person_image, masked_person , cloth_image], 3, 1)
|
| 266 |
+
conditions = conditions.resize((condition_width, height), Image.NEAREST)
|
| 267 |
+
new_result_image = Image.new("RGB", (width + condition_width + 5, height))
|
| 268 |
+
new_result_image.paste(conditions, (0, 0))
|
| 269 |
+
new_result_image.paste(result_image, (condition_width + 5, 0))
|
| 270 |
+
return new_result_image
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def person_example_fn(image_path):
|
| 275 |
+
return image_path
|
| 276 |
+
|
| 277 |
+
HEADER = """
|
| 278 |
+
<h1 style="text-align: center;"> 🐈 CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models </h1>
|
| 279 |
+
<div style="display: flex; justify-content: center; align-items: center;">
|
| 280 |
+
<a href="http://arxiv.org/abs/2407.15886" style="margin: 0 2px;">
|
| 281 |
+
<img src='https://img.shields.io/badge/arXiv-2407.15886-red?style=flat&logo=arXiv&logoColor=red' alt='arxiv'>
|
| 282 |
+
</a>
|
| 283 |
+
<a href='https://huggingface.co/zhengchong/CatVTON' style="margin: 0 2px;">
|
| 284 |
+
<img src='https://img.shields.io/badge/Hugging Face-ckpts-orange?style=flat&logo=HuggingFace&logoColor=orange' alt='huggingface'>
|
| 285 |
+
</a>
|
| 286 |
+
<a href="https://github.com/Zheng-Chong/CatVTON" style="margin: 0 2px;">
|
| 287 |
+
<img src='https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub' alt='GitHub'>
|
| 288 |
+
</a>
|
| 289 |
+
<a href="http://120.76.142.206:8888" style="margin: 0 2px;">
|
| 290 |
+
<img src='https://img.shields.io/badge/Demo-Gradio-gold?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
|
| 291 |
+
</a>
|
| 292 |
+
<a href="https://huggingface.co/spaces/zhengchong/CatVTON" style="margin: 0 2px;">
|
| 293 |
+
<img src='https://img.shields.io/badge/Space-ZeroGPU-orange?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
|
| 294 |
+
</a>
|
| 295 |
+
<a href='https://zheng-chong.github.io/CatVTON/' style="margin: 0 2px;">
|
| 296 |
+
<img src='https://img.shields.io/badge/Webpage-Project-silver?style=flat&logo=&logoColor=orange' alt='webpage'>
|
| 297 |
+
</a>
|
| 298 |
+
<a href="https://github.com/Zheng-Chong/CatVTON/LICENCE" style="margin: 0 2px;">
|
| 299 |
+
<img src='https://img.shields.io/badge/License-CC BY--NC--SA--4.0-lightgreen?style=flat&logo=Lisence' alt='License'>
|
| 300 |
+
</a>
|
| 301 |
+
</div>
|
| 302 |
+
<br>
|
| 303 |
+
· This demo and our weights are only for <span>Non-commercial Use</span>. <br>
|
| 304 |
+
· You can try CatVTON in our <a href="https://huggingface.co/spaces/zhengchong/CatVTON">HuggingFace Space</a> or our <a href="http://120.76.142.206:8888">online demo</a> (run on 3090). <br>
|
| 305 |
+
· Thanks to <a href="https://huggingface.co/zero-gpu-explorers">ZeroGPU</a> for providing A100 for our <a href="https://huggingface.co/spaces/zhengchong/CatVTON">HuggingFace Space</a>. <br>
|
| 306 |
+
· SafetyChecker is set to filter NSFW content, but it may block normal results too. Please adjust the <span>`seed`</span> for normal outcomes.<br>
|
| 307 |
+
"""
|
| 308 |
+
|
| 309 |
+
def app_gradio():
|
| 310 |
+
with gr.Blocks(title="CatVTON") as demo:
|
| 311 |
+
gr.Markdown(HEADER)
|
| 312 |
+
with gr.Tab("Mask-based Virtual Try-On"):
|
| 313 |
+
with gr.Row():
|
| 314 |
+
with gr.Column(scale=1, min_width=350):
|
| 315 |
+
with gr.Row():
|
| 316 |
+
image_path = gr.Image(
|
| 317 |
+
type="filepath",
|
| 318 |
+
interactive=True,
|
| 319 |
+
visible=False,
|
| 320 |
+
)
|
| 321 |
+
person_image = gr.ImageEditor(
|
| 322 |
+
interactive=True, label="Person Image", type="filepath"
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
with gr.Row():
|
| 326 |
+
with gr.Column(scale=1, min_width=230):
|
| 327 |
+
cloth_image = gr.Image(
|
| 328 |
+
interactive=True, label="Condition Image", type="filepath"
|
| 329 |
+
)
|
| 330 |
+
with gr.Column(scale=1, min_width=120):
|
| 331 |
+
gr.Markdown(
|
| 332 |
+
'<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
|
| 333 |
+
)
|
| 334 |
+
cloth_type = gr.Radio(
|
| 335 |
+
label="Try-On Cloth Type",
|
| 336 |
+
choices=["upper", "lower", "overall"],
|
| 337 |
+
value="upper",
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
submit = gr.Button("Submit")
|
| 342 |
+
gr.Markdown(
|
| 343 |
+
'<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
gr.Markdown(
|
| 347 |
+
'<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
|
| 348 |
+
)
|
| 349 |
+
with gr.Accordion("Advanced Options", open=False):
|
| 350 |
+
num_inference_steps = gr.Slider(
|
| 351 |
+
label="Inference Step", minimum=10, maximum=100, step=5, value=50
|
| 352 |
+
)
|
| 353 |
+
# Guidence Scale
|
| 354 |
+
guidance_scale = gr.Slider(
|
| 355 |
+
label="CFG Strenth", minimum=0.0, maximum=7.5, step=0.5, value=2.5
|
| 356 |
+
)
|
| 357 |
+
# Random Seed
|
| 358 |
+
seed = gr.Slider(
|
| 359 |
+
label="Seed", minimum=-1, maximum=10000, step=1, value=42
|
| 360 |
+
)
|
| 361 |
+
show_type = gr.Radio(
|
| 362 |
+
label="Show Type",
|
| 363 |
+
choices=["result only", "input & result", "input & mask & result"],
|
| 364 |
+
value="input & mask & result",
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
with gr.Column(scale=2, min_width=500):
|
| 368 |
+
result_image = gr.Image(interactive=False, label="Result")
|
| 369 |
+
with gr.Row():
|
| 370 |
+
# Photo Examples
|
| 371 |
+
root_path = "resource/demo/example"
|
| 372 |
+
with gr.Column():
|
| 373 |
+
men_exm = gr.Examples(
|
| 374 |
+
examples=[
|
| 375 |
+
os.path.join(root_path, "person", "men", _)
|
| 376 |
+
for _ in os.listdir(os.path.join(root_path, "person", "men"))
|
| 377 |
+
],
|
| 378 |
+
examples_per_page=4,
|
| 379 |
+
inputs=image_path,
|
| 380 |
+
label="Person Examples ①",
|
| 381 |
+
)
|
| 382 |
+
women_exm = gr.Examples(
|
| 383 |
+
examples=[
|
| 384 |
+
os.path.join(root_path, "person", "women", _)
|
| 385 |
+
for _ in os.listdir(os.path.join(root_path, "person", "women"))
|
| 386 |
+
],
|
| 387 |
+
examples_per_page=4,
|
| 388 |
+
inputs=image_path,
|
| 389 |
+
label="Person Examples ②",
|
| 390 |
+
)
|
| 391 |
+
gr.Markdown(
|
| 392 |
+
'<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
|
| 393 |
+
)
|
| 394 |
+
with gr.Column():
|
| 395 |
+
condition_upper_exm = gr.Examples(
|
| 396 |
+
examples=[
|
| 397 |
+
os.path.join(root_path, "condition", "upper", _)
|
| 398 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
|
| 399 |
+
],
|
| 400 |
+
examples_per_page=4,
|
| 401 |
+
inputs=cloth_image,
|
| 402 |
+
label="Condition Upper Examples",
|
| 403 |
+
)
|
| 404 |
+
condition_overall_exm = gr.Examples(
|
| 405 |
+
examples=[
|
| 406 |
+
os.path.join(root_path, "condition", "overall", _)
|
| 407 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
|
| 408 |
+
],
|
| 409 |
+
examples_per_page=4,
|
| 410 |
+
inputs=cloth_image,
|
| 411 |
+
label="Condition Overall Examples",
|
| 412 |
+
)
|
| 413 |
+
condition_person_exm = gr.Examples(
|
| 414 |
+
examples=[
|
| 415 |
+
os.path.join(root_path, "condition", "person", _)
|
| 416 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "person"))
|
| 417 |
+
],
|
| 418 |
+
examples_per_page=4,
|
| 419 |
+
inputs=cloth_image,
|
| 420 |
+
label="Condition Reference Person Examples",
|
| 421 |
+
)
|
| 422 |
+
gr.Markdown(
|
| 423 |
+
'<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
image_path.change(
|
| 427 |
+
person_example_fn, inputs=image_path, outputs=person_image
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
submit.click(
|
| 431 |
+
submit_function,
|
| 432 |
+
[
|
| 433 |
+
person_image,
|
| 434 |
+
cloth_image,
|
| 435 |
+
cloth_type,
|
| 436 |
+
num_inference_steps,
|
| 437 |
+
guidance_scale,
|
| 438 |
+
seed,
|
| 439 |
+
show_type,
|
| 440 |
+
],
|
| 441 |
+
result_image,
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
with gr.Tab("Mask-Free Virtual Try-On"):
|
| 445 |
+
with gr.Row():
|
| 446 |
+
with gr.Column(scale=1, min_width=350):
|
| 447 |
+
with gr.Row():
|
| 448 |
+
image_path_p2p = gr.Image(
|
| 449 |
+
type="filepath",
|
| 450 |
+
interactive=True,
|
| 451 |
+
visible=False,
|
| 452 |
+
)
|
| 453 |
+
person_image_p2p = gr.ImageEditor(
|
| 454 |
+
interactive=True, label="Person Image", type="filepath"
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
with gr.Row():
|
| 458 |
+
with gr.Column(scale=1, min_width=230):
|
| 459 |
+
cloth_image_p2p = gr.Image(
|
| 460 |
+
interactive=True, label="Condition Image", type="filepath"
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
submit_p2p = gr.Button("Submit")
|
| 464 |
+
gr.Markdown(
|
| 465 |
+
'<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
gr.Markdown(
|
| 469 |
+
'<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
|
| 470 |
+
)
|
| 471 |
+
with gr.Accordion("Advanced Options", open=False):
|
| 472 |
+
num_inference_steps_p2p = gr.Slider(
|
| 473 |
+
label="Inference Step", minimum=10, maximum=100, step=5, value=50
|
| 474 |
+
)
|
| 475 |
+
# Guidence Scale
|
| 476 |
+
guidance_scale_p2p = gr.Slider(
|
| 477 |
+
label="CFG Strenth", minimum=0.0, maximum=7.5, step=0.5, value=2.5
|
| 478 |
+
)
|
| 479 |
+
# Random Seed
|
| 480 |
+
seed_p2p = gr.Slider(
|
| 481 |
+
label="Seed", minimum=-1, maximum=10000, step=1, value=42
|
| 482 |
+
)
|
| 483 |
+
# show_type = gr.Radio(
|
| 484 |
+
# label="Show Type",
|
| 485 |
+
# choices=["result only", "input & result", "input & mask & result"],
|
| 486 |
+
# value="input & mask & result",
|
| 487 |
+
# )
|
| 488 |
+
|
| 489 |
+
with gr.Column(scale=2, min_width=500):
|
| 490 |
+
result_image_p2p = gr.Image(interactive=False, label="Result")
|
| 491 |
+
with gr.Row():
|
| 492 |
+
# Photo Examples
|
| 493 |
+
root_path = "resource/demo/example"
|
| 494 |
+
with gr.Column():
|
| 495 |
+
gr.Examples(
|
| 496 |
+
examples=[
|
| 497 |
+
os.path.join(root_path, "person", "men", _)
|
| 498 |
+
for _ in os.listdir(os.path.join(root_path, "person", "men"))
|
| 499 |
+
],
|
| 500 |
+
examples_per_page=4,
|
| 501 |
+
inputs=image_path_p2p,
|
| 502 |
+
label="Person Examples ①",
|
| 503 |
+
)
|
| 504 |
+
gr.Examples(
|
| 505 |
+
examples=[
|
| 506 |
+
os.path.join(root_path, "person", "women", _)
|
| 507 |
+
for _ in os.listdir(os.path.join(root_path, "person", "women"))
|
| 508 |
+
],
|
| 509 |
+
examples_per_page=4,
|
| 510 |
+
inputs=image_path_p2p,
|
| 511 |
+
label="Person Examples ②",
|
| 512 |
+
)
|
| 513 |
+
gr.Markdown(
|
| 514 |
+
'<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
|
| 515 |
+
)
|
| 516 |
+
with gr.Column():
|
| 517 |
+
gr.Examples(
|
| 518 |
+
examples=[
|
| 519 |
+
os.path.join(root_path, "condition", "upper", _)
|
| 520 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
|
| 521 |
+
],
|
| 522 |
+
examples_per_page=4,
|
| 523 |
+
inputs=cloth_image_p2p,
|
| 524 |
+
label="Condition Upper Examples",
|
| 525 |
+
)
|
| 526 |
+
gr.Examples(
|
| 527 |
+
examples=[
|
| 528 |
+
os.path.join(root_path, "condition", "overall", _)
|
| 529 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
|
| 530 |
+
],
|
| 531 |
+
examples_per_page=4,
|
| 532 |
+
inputs=cloth_image_p2p,
|
| 533 |
+
label="Condition Overall Examples",
|
| 534 |
+
)
|
| 535 |
+
condition_person_exm = gr.Examples(
|
| 536 |
+
examples=[
|
| 537 |
+
os.path.join(root_path, "condition", "person", _)
|
| 538 |
+
for _ in os.listdir(os.path.join(root_path, "condition", "person"))
|
| 539 |
+
],
|
| 540 |
+
examples_per_page=4,
|
| 541 |
+
inputs=cloth_image_p2p,
|
| 542 |
+
label="Condition Reference Person Examples",
|
| 543 |
+
)
|
| 544 |
+
gr.Markdown(
|
| 545 |
+
'<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
image_path_p2p.change(
|
| 549 |
+
person_example_fn, inputs=image_path_p2p, outputs=person_image_p2p
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
submit_p2p.click(
|
| 553 |
+
submit_function_p2p,
|
| 554 |
+
[
|
| 555 |
+
person_image_p2p,
|
| 556 |
+
cloth_image_p2p,
|
| 557 |
+
num_inference_steps_p2p,
|
| 558 |
+
guidance_scale_p2p,
|
| 559 |
+
seed_p2p],
|
| 560 |
+
result_image_p2p,
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
demo.queue().launch(share=True, show_error=True)
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
if __name__ == "__main__":
|
| 567 |
+
app_gradio()
|
densepose/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
from .data.datasets import builtin # just to register data
|
| 5 |
+
from .converters import builtin as builtin_converters # register converters
|
| 6 |
+
from .config import (
|
| 7 |
+
add_densepose_config,
|
| 8 |
+
add_densepose_head_config,
|
| 9 |
+
add_hrnet_config,
|
| 10 |
+
add_dataset_category_config,
|
| 11 |
+
add_bootstrap_config,
|
| 12 |
+
load_bootstrap_config,
|
| 13 |
+
)
|
| 14 |
+
from .structures import DensePoseDataRelative, DensePoseList, DensePoseTransformData
|
| 15 |
+
from .evaluation import DensePoseCOCOEvaluator
|
| 16 |
+
from .modeling.roi_heads import DensePoseROIHeads
|
| 17 |
+
from .modeling.test_time_augmentation import (
|
| 18 |
+
DensePoseGeneralizedRCNNWithTTA,
|
| 19 |
+
DensePoseDatasetMapperTTA,
|
| 20 |
+
)
|
| 21 |
+
from .utils.transform import load_from_cfg
|
| 22 |
+
from .modeling.hrfpn import build_hrfpn_backbone
|
densepose/config.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding = utf-8 -*-
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
# pyre-ignore-all-errors
|
| 4 |
+
|
| 5 |
+
from detectron2.config import CfgNode as CN
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def add_dataset_category_config(cfg: CN) -> None:
|
| 9 |
+
"""
|
| 10 |
+
Add config for additional category-related dataset options
|
| 11 |
+
- category whitelisting
|
| 12 |
+
- category mapping
|
| 13 |
+
"""
|
| 14 |
+
_C = cfg
|
| 15 |
+
_C.DATASETS.CATEGORY_MAPS = CN(new_allowed=True)
|
| 16 |
+
_C.DATASETS.WHITELISTED_CATEGORIES = CN(new_allowed=True)
|
| 17 |
+
# class to mesh mapping
|
| 18 |
+
_C.DATASETS.CLASS_TO_MESH_NAME_MAPPING = CN(new_allowed=True)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def add_evaluation_config(cfg: CN) -> None:
|
| 22 |
+
_C = cfg
|
| 23 |
+
_C.DENSEPOSE_EVALUATION = CN()
|
| 24 |
+
# evaluator type, possible values:
|
| 25 |
+
# - "iou": evaluator for models that produce iou data
|
| 26 |
+
# - "cse": evaluator for models that produce cse data
|
| 27 |
+
_C.DENSEPOSE_EVALUATION.TYPE = "iou"
|
| 28 |
+
# storage for DensePose results, possible values:
|
| 29 |
+
# - "none": no explicit storage, all the results are stored in the
|
| 30 |
+
# dictionary with predictions, memory intensive;
|
| 31 |
+
# historically the default storage type
|
| 32 |
+
# - "ram": RAM storage, uses per-process RAM storage, which is
|
| 33 |
+
# reduced to a single process storage on later stages,
|
| 34 |
+
# less memory intensive
|
| 35 |
+
# - "file": file storage, uses per-process file-based storage,
|
| 36 |
+
# the least memory intensive, but may create bottlenecks
|
| 37 |
+
# on file system accesses
|
| 38 |
+
_C.DENSEPOSE_EVALUATION.STORAGE = "none"
|
| 39 |
+
# minimum threshold for IOU values: the lower its values is,
|
| 40 |
+
# the more matches are produced (and the higher the AP score)
|
| 41 |
+
_C.DENSEPOSE_EVALUATION.MIN_IOU_THRESHOLD = 0.5
|
| 42 |
+
# Non-distributed inference is slower (at inference time) but can avoid RAM OOM
|
| 43 |
+
_C.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE = True
|
| 44 |
+
# evaluate mesh alignment based on vertex embeddings, only makes sense in CSE context
|
| 45 |
+
_C.DENSEPOSE_EVALUATION.EVALUATE_MESH_ALIGNMENT = False
|
| 46 |
+
# meshes to compute mesh alignment for
|
| 47 |
+
_C.DENSEPOSE_EVALUATION.MESH_ALIGNMENT_MESH_NAMES = []
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def add_bootstrap_config(cfg: CN) -> None:
|
| 51 |
+
""" """
|
| 52 |
+
_C = cfg
|
| 53 |
+
_C.BOOTSTRAP_DATASETS = []
|
| 54 |
+
_C.BOOTSTRAP_MODEL = CN()
|
| 55 |
+
_C.BOOTSTRAP_MODEL.WEIGHTS = ""
|
| 56 |
+
_C.BOOTSTRAP_MODEL.DEVICE = "cuda"
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def get_bootstrap_dataset_config() -> CN:
|
| 60 |
+
_C = CN()
|
| 61 |
+
_C.DATASET = ""
|
| 62 |
+
# ratio used to mix data loaders
|
| 63 |
+
_C.RATIO = 0.1
|
| 64 |
+
# image loader
|
| 65 |
+
_C.IMAGE_LOADER = CN(new_allowed=True)
|
| 66 |
+
_C.IMAGE_LOADER.TYPE = ""
|
| 67 |
+
_C.IMAGE_LOADER.BATCH_SIZE = 4
|
| 68 |
+
_C.IMAGE_LOADER.NUM_WORKERS = 4
|
| 69 |
+
_C.IMAGE_LOADER.CATEGORIES = []
|
| 70 |
+
_C.IMAGE_LOADER.MAX_COUNT_PER_CATEGORY = 1_000_000
|
| 71 |
+
_C.IMAGE_LOADER.CATEGORY_TO_CLASS_MAPPING = CN(new_allowed=True)
|
| 72 |
+
# inference
|
| 73 |
+
_C.INFERENCE = CN()
|
| 74 |
+
# batch size for model inputs
|
| 75 |
+
_C.INFERENCE.INPUT_BATCH_SIZE = 4
|
| 76 |
+
# batch size to group model outputs
|
| 77 |
+
_C.INFERENCE.OUTPUT_BATCH_SIZE = 2
|
| 78 |
+
# sampled data
|
| 79 |
+
_C.DATA_SAMPLER = CN(new_allowed=True)
|
| 80 |
+
_C.DATA_SAMPLER.TYPE = ""
|
| 81 |
+
_C.DATA_SAMPLER.USE_GROUND_TRUTH_CATEGORIES = False
|
| 82 |
+
# filter
|
| 83 |
+
_C.FILTER = CN(new_allowed=True)
|
| 84 |
+
_C.FILTER.TYPE = ""
|
| 85 |
+
return _C
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def load_bootstrap_config(cfg: CN) -> None:
|
| 89 |
+
"""
|
| 90 |
+
Bootstrap datasets are given as a list of `dict` that are not automatically
|
| 91 |
+
converted into CfgNode. This method processes all bootstrap dataset entries
|
| 92 |
+
and ensures that they are in CfgNode format and comply with the specification
|
| 93 |
+
"""
|
| 94 |
+
if not cfg.BOOTSTRAP_DATASETS:
|
| 95 |
+
return
|
| 96 |
+
|
| 97 |
+
bootstrap_datasets_cfgnodes = []
|
| 98 |
+
for dataset_cfg in cfg.BOOTSTRAP_DATASETS:
|
| 99 |
+
_C = get_bootstrap_dataset_config().clone()
|
| 100 |
+
_C.merge_from_other_cfg(CN(dataset_cfg))
|
| 101 |
+
bootstrap_datasets_cfgnodes.append(_C)
|
| 102 |
+
cfg.BOOTSTRAP_DATASETS = bootstrap_datasets_cfgnodes
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def add_densepose_head_cse_config(cfg: CN) -> None:
|
| 106 |
+
"""
|
| 107 |
+
Add configuration options for Continuous Surface Embeddings (CSE)
|
| 108 |
+
"""
|
| 109 |
+
_C = cfg
|
| 110 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE = CN()
|
| 111 |
+
# Dimensionality D of the embedding space
|
| 112 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE = 16
|
| 113 |
+
# Embedder specifications for various mesh IDs
|
| 114 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS = CN(new_allowed=True)
|
| 115 |
+
# normalization coefficient for embedding distances
|
| 116 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_DIST_GAUSS_SIGMA = 0.01
|
| 117 |
+
# normalization coefficient for geodesic distances
|
| 118 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.GEODESIC_DIST_GAUSS_SIGMA = 0.01
|
| 119 |
+
# embedding loss weight
|
| 120 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_WEIGHT = 0.6
|
| 121 |
+
# embedding loss name, currently the following options are supported:
|
| 122 |
+
# - EmbeddingLoss: cross-entropy on vertex labels
|
| 123 |
+
# - SoftEmbeddingLoss: cross-entropy on vertex label combined with
|
| 124 |
+
# Gaussian penalty on distance between vertices
|
| 125 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME = "EmbeddingLoss"
|
| 126 |
+
# optimizer hyperparameters
|
| 127 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.FEATURES_LR_FACTOR = 1.0
|
| 128 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_LR_FACTOR = 1.0
|
| 129 |
+
# Shape to shape cycle consistency loss parameters:
|
| 130 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS = CN({"ENABLED": False})
|
| 131 |
+
# shape to shape cycle consistency loss weight
|
| 132 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.WEIGHT = 0.025
|
| 133 |
+
# norm type used for loss computation
|
| 134 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.NORM_P = 2
|
| 135 |
+
# normalization term for embedding similarity matrices
|
| 136 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.TEMPERATURE = 0.05
|
| 137 |
+
# maximum number of vertices to include into shape to shape cycle loss
|
| 138 |
+
# if negative or zero, all vertices are considered
|
| 139 |
+
# if positive, random subset of vertices of given size is considered
|
| 140 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.MAX_NUM_VERTICES = 4936
|
| 141 |
+
# Pixel to shape cycle consistency loss parameters:
|
| 142 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS = CN({"ENABLED": False})
|
| 143 |
+
# pixel to shape cycle consistency loss weight
|
| 144 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.WEIGHT = 0.0001
|
| 145 |
+
# norm type used for loss computation
|
| 146 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.NORM_P = 2
|
| 147 |
+
# map images to all meshes and back (if false, use only gt meshes from the batch)
|
| 148 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.USE_ALL_MESHES_NOT_GT_ONLY = False
|
| 149 |
+
# Randomly select at most this number of pixels from every instance
|
| 150 |
+
# if negative or zero, all vertices are considered
|
| 151 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.NUM_PIXELS_TO_SAMPLE = 100
|
| 152 |
+
# normalization factor for pixel to pixel distances (higher value = smoother distribution)
|
| 153 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.PIXEL_SIGMA = 5.0
|
| 154 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.TEMPERATURE_PIXEL_TO_VERTEX = 0.05
|
| 155 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.TEMPERATURE_VERTEX_TO_PIXEL = 0.05
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def add_densepose_head_config(cfg: CN) -> None:
|
| 159 |
+
"""
|
| 160 |
+
Add config for densepose head.
|
| 161 |
+
"""
|
| 162 |
+
_C = cfg
|
| 163 |
+
|
| 164 |
+
_C.MODEL.DENSEPOSE_ON = True
|
| 165 |
+
|
| 166 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD = CN()
|
| 167 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.NAME = ""
|
| 168 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.NUM_STACKED_CONVS = 8
|
| 169 |
+
# Number of parts used for point labels
|
| 170 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.NUM_PATCHES = 24
|
| 171 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DECONV_KERNEL = 4
|
| 172 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_DIM = 512
|
| 173 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_KERNEL = 3
|
| 174 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.UP_SCALE = 2
|
| 175 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE = 112
|
| 176 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_TYPE = "ROIAlignV2"
|
| 177 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_RESOLUTION = 28
|
| 178 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_SAMPLING_RATIO = 2
|
| 179 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS = 2 # 15 or 2
|
| 180 |
+
# Overlap threshold for an RoI to be considered foreground (if >= FG_IOU_THRESHOLD)
|
| 181 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.FG_IOU_THRESHOLD = 0.7
|
| 182 |
+
# Loss weights for annotation masks.(14 Parts)
|
| 183 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.INDEX_WEIGHTS = 5.0
|
| 184 |
+
# Loss weights for surface parts. (24 Parts)
|
| 185 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.PART_WEIGHTS = 1.0
|
| 186 |
+
# Loss weights for UV regression.
|
| 187 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.POINT_REGRESSION_WEIGHTS = 0.01
|
| 188 |
+
# Coarse segmentation is trained using instance segmentation task data
|
| 189 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS = False
|
| 190 |
+
# For Decoder
|
| 191 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_ON = True
|
| 192 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_NUM_CLASSES = 256
|
| 193 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_CONV_DIMS = 256
|
| 194 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_NORM = ""
|
| 195 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_COMMON_STRIDE = 4
|
| 196 |
+
# For DeepLab head
|
| 197 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB = CN()
|
| 198 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NORM = "GN"
|
| 199 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NONLOCAL_ON = 0
|
| 200 |
+
# Predictor class name, must be registered in DENSEPOSE_PREDICTOR_REGISTRY
|
| 201 |
+
# Some registered predictors:
|
| 202 |
+
# "DensePoseChartPredictor": predicts segmentation and UV coordinates for predefined charts
|
| 203 |
+
# "DensePoseChartWithConfidencePredictor": predicts segmentation, UV coordinates
|
| 204 |
+
# and associated confidences for predefined charts (default)
|
| 205 |
+
# "DensePoseEmbeddingWithConfidencePredictor": predicts segmentation, embeddings
|
| 206 |
+
# and associated confidences for CSE
|
| 207 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.PREDICTOR_NAME = "DensePoseChartWithConfidencePredictor"
|
| 208 |
+
# Loss class name, must be registered in DENSEPOSE_LOSS_REGISTRY
|
| 209 |
+
# Some registered losses:
|
| 210 |
+
# "DensePoseChartLoss": loss for chart-based models that estimate
|
| 211 |
+
# segmentation and UV coordinates
|
| 212 |
+
# "DensePoseChartWithConfidenceLoss": loss for chart-based models that estimate
|
| 213 |
+
# segmentation, UV coordinates and the corresponding confidences (default)
|
| 214 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.LOSS_NAME = "DensePoseChartWithConfidenceLoss"
|
| 215 |
+
# Confidences
|
| 216 |
+
# Enable learning UV confidences (variances) along with the actual values
|
| 217 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE = CN({"ENABLED": False})
|
| 218 |
+
# UV confidence lower bound
|
| 219 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE.EPSILON = 0.01
|
| 220 |
+
# Enable learning segmentation confidences (variances) along with the actual values
|
| 221 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.SEGM_CONFIDENCE = CN({"ENABLED": False})
|
| 222 |
+
# Segmentation confidence lower bound
|
| 223 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.SEGM_CONFIDENCE.EPSILON = 0.01
|
| 224 |
+
# Statistical model type for confidence learning, possible values:
|
| 225 |
+
# - "iid_iso": statistically independent identically distributed residuals
|
| 226 |
+
# with isotropic covariance
|
| 227 |
+
# - "indep_aniso": statistically independent residuals with anisotropic
|
| 228 |
+
# covariances
|
| 229 |
+
_C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE.TYPE = "iid_iso"
|
| 230 |
+
# List of angles for rotation in data augmentation during training
|
| 231 |
+
_C.INPUT.ROTATION_ANGLES = [0]
|
| 232 |
+
_C.TEST.AUG.ROTATION_ANGLES = () # Rotation TTA
|
| 233 |
+
|
| 234 |
+
add_densepose_head_cse_config(cfg)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def add_hrnet_config(cfg: CN) -> None:
|
| 238 |
+
"""
|
| 239 |
+
Add config for HRNet backbone.
|
| 240 |
+
"""
|
| 241 |
+
_C = cfg
|
| 242 |
+
|
| 243 |
+
# For HigherHRNet w32
|
| 244 |
+
_C.MODEL.HRNET = CN()
|
| 245 |
+
_C.MODEL.HRNET.STEM_INPLANES = 64
|
| 246 |
+
_C.MODEL.HRNET.STAGE2 = CN()
|
| 247 |
+
_C.MODEL.HRNET.STAGE2.NUM_MODULES = 1
|
| 248 |
+
_C.MODEL.HRNET.STAGE2.NUM_BRANCHES = 2
|
| 249 |
+
_C.MODEL.HRNET.STAGE2.BLOCK = "BASIC"
|
| 250 |
+
_C.MODEL.HRNET.STAGE2.NUM_BLOCKS = [4, 4]
|
| 251 |
+
_C.MODEL.HRNET.STAGE2.NUM_CHANNELS = [32, 64]
|
| 252 |
+
_C.MODEL.HRNET.STAGE2.FUSE_METHOD = "SUM"
|
| 253 |
+
_C.MODEL.HRNET.STAGE3 = CN()
|
| 254 |
+
_C.MODEL.HRNET.STAGE3.NUM_MODULES = 4
|
| 255 |
+
_C.MODEL.HRNET.STAGE3.NUM_BRANCHES = 3
|
| 256 |
+
_C.MODEL.HRNET.STAGE3.BLOCK = "BASIC"
|
| 257 |
+
_C.MODEL.HRNET.STAGE3.NUM_BLOCKS = [4, 4, 4]
|
| 258 |
+
_C.MODEL.HRNET.STAGE3.NUM_CHANNELS = [32, 64, 128]
|
| 259 |
+
_C.MODEL.HRNET.STAGE3.FUSE_METHOD = "SUM"
|
| 260 |
+
_C.MODEL.HRNET.STAGE4 = CN()
|
| 261 |
+
_C.MODEL.HRNET.STAGE4.NUM_MODULES = 3
|
| 262 |
+
_C.MODEL.HRNET.STAGE4.NUM_BRANCHES = 4
|
| 263 |
+
_C.MODEL.HRNET.STAGE4.BLOCK = "BASIC"
|
| 264 |
+
_C.MODEL.HRNET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
|
| 265 |
+
_C.MODEL.HRNET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
|
| 266 |
+
_C.MODEL.HRNET.STAGE4.FUSE_METHOD = "SUM"
|
| 267 |
+
|
| 268 |
+
_C.MODEL.HRNET.HRFPN = CN()
|
| 269 |
+
_C.MODEL.HRNET.HRFPN.OUT_CHANNELS = 256
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def add_densepose_config(cfg: CN) -> None:
|
| 273 |
+
add_densepose_head_config(cfg)
|
| 274 |
+
add_hrnet_config(cfg)
|
| 275 |
+
add_bootstrap_config(cfg)
|
| 276 |
+
add_dataset_category_config(cfg)
|
| 277 |
+
add_evaluation_config(cfg)
|
densepose/converters/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from .hflip import HFlipConverter
|
| 6 |
+
from .to_mask import ToMaskConverter
|
| 7 |
+
from .to_chart_result import ToChartResultConverter, ToChartResultConverterWithConfidences
|
| 8 |
+
from .segm_to_mask import (
|
| 9 |
+
predictor_output_with_fine_and_coarse_segm_to_mask,
|
| 10 |
+
predictor_output_with_coarse_segm_to_mask,
|
| 11 |
+
resample_fine_and_coarse_segm_to_bbox,
|
| 12 |
+
)
|
| 13 |
+
from .chart_output_to_chart_result import (
|
| 14 |
+
densepose_chart_predictor_output_to_result,
|
| 15 |
+
densepose_chart_predictor_output_to_result_with_confidences,
|
| 16 |
+
)
|
| 17 |
+
from .chart_output_hflip import densepose_chart_predictor_output_hflip
|
densepose/converters/base.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from typing import Any, Tuple, Type
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class BaseConverter:
|
| 10 |
+
"""
|
| 11 |
+
Converter base class to be reused by various converters.
|
| 12 |
+
Converter allows one to convert data from various source types to a particular
|
| 13 |
+
destination type. Each source type needs to register its converter. The
|
| 14 |
+
registration for each source type is valid for all descendants of that type.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
@classmethod
|
| 18 |
+
def register(cls, from_type: Type, converter: Any = None):
|
| 19 |
+
"""
|
| 20 |
+
Registers a converter for the specified type.
|
| 21 |
+
Can be used as a decorator (if converter is None), or called as a method.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
from_type (type): type to register the converter for;
|
| 25 |
+
all instances of this type will use the same converter
|
| 26 |
+
converter (callable): converter to be registered for the given
|
| 27 |
+
type; if None, this method is assumed to be a decorator for the converter
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
if converter is not None:
|
| 31 |
+
cls._do_register(from_type, converter)
|
| 32 |
+
|
| 33 |
+
def wrapper(converter: Any) -> Any:
|
| 34 |
+
cls._do_register(from_type, converter)
|
| 35 |
+
return converter
|
| 36 |
+
|
| 37 |
+
return wrapper
|
| 38 |
+
|
| 39 |
+
@classmethod
|
| 40 |
+
def _do_register(cls, from_type: Type, converter: Any):
|
| 41 |
+
cls.registry[from_type] = converter # pyre-ignore[16]
|
| 42 |
+
|
| 43 |
+
@classmethod
|
| 44 |
+
def _lookup_converter(cls, from_type: Type) -> Any:
|
| 45 |
+
"""
|
| 46 |
+
Perform recursive lookup for the given type
|
| 47 |
+
to find registered converter. If a converter was found for some base
|
| 48 |
+
class, it gets registered for this class to save on further lookups.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
from_type: type for which to find a converter
|
| 52 |
+
Return:
|
| 53 |
+
callable or None - registered converter or None
|
| 54 |
+
if no suitable entry was found in the registry
|
| 55 |
+
"""
|
| 56 |
+
if from_type in cls.registry: # pyre-ignore[16]
|
| 57 |
+
return cls.registry[from_type]
|
| 58 |
+
for base in from_type.__bases__:
|
| 59 |
+
converter = cls._lookup_converter(base)
|
| 60 |
+
if converter is not None:
|
| 61 |
+
cls._do_register(from_type, converter)
|
| 62 |
+
return converter
|
| 63 |
+
return None
|
| 64 |
+
|
| 65 |
+
@classmethod
|
| 66 |
+
def convert(cls, instance: Any, *args, **kwargs):
|
| 67 |
+
"""
|
| 68 |
+
Convert an instance to the destination type using some registered
|
| 69 |
+
converter. Does recursive lookup for base classes, so there's no need
|
| 70 |
+
for explicit registration for derived classes.
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
instance: source instance to convert to the destination type
|
| 74 |
+
Return:
|
| 75 |
+
An instance of the destination type obtained from the source instance
|
| 76 |
+
Raises KeyError, if no suitable converter found
|
| 77 |
+
"""
|
| 78 |
+
instance_type = type(instance)
|
| 79 |
+
converter = cls._lookup_converter(instance_type)
|
| 80 |
+
if converter is None:
|
| 81 |
+
if cls.dst_type is None: # pyre-ignore[16]
|
| 82 |
+
output_type_str = "itself"
|
| 83 |
+
else:
|
| 84 |
+
output_type_str = cls.dst_type
|
| 85 |
+
raise KeyError(f"Could not find converter from {instance_type} to {output_type_str}")
|
| 86 |
+
return converter(instance, *args, **kwargs)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
IntTupleBox = Tuple[int, int, int, int]
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def make_int_box(box: torch.Tensor) -> IntTupleBox:
|
| 93 |
+
int_box = [0, 0, 0, 0]
|
| 94 |
+
int_box[0], int_box[1], int_box[2], int_box[3] = tuple(box.long().tolist())
|
| 95 |
+
return int_box[0], int_box[1], int_box[2], int_box[3]
|
densepose/converters/builtin.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from ..structures import DensePoseChartPredictorOutput, DensePoseEmbeddingPredictorOutput
|
| 6 |
+
from . import (
|
| 7 |
+
HFlipConverter,
|
| 8 |
+
ToChartResultConverter,
|
| 9 |
+
ToChartResultConverterWithConfidences,
|
| 10 |
+
ToMaskConverter,
|
| 11 |
+
densepose_chart_predictor_output_hflip,
|
| 12 |
+
densepose_chart_predictor_output_to_result,
|
| 13 |
+
densepose_chart_predictor_output_to_result_with_confidences,
|
| 14 |
+
predictor_output_with_coarse_segm_to_mask,
|
| 15 |
+
predictor_output_with_fine_and_coarse_segm_to_mask,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
ToMaskConverter.register(
|
| 19 |
+
DensePoseChartPredictorOutput, predictor_output_with_fine_and_coarse_segm_to_mask
|
| 20 |
+
)
|
| 21 |
+
ToMaskConverter.register(
|
| 22 |
+
DensePoseEmbeddingPredictorOutput, predictor_output_with_coarse_segm_to_mask
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
ToChartResultConverter.register(
|
| 26 |
+
DensePoseChartPredictorOutput, densepose_chart_predictor_output_to_result
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
ToChartResultConverterWithConfidences.register(
|
| 30 |
+
DensePoseChartPredictorOutput, densepose_chart_predictor_output_to_result_with_confidences
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
HFlipConverter.register(DensePoseChartPredictorOutput, densepose_chart_predictor_output_hflip)
|
densepose/converters/chart_output_hflip.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
from dataclasses import fields
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from densepose.structures import DensePoseChartPredictorOutput, DensePoseTransformData
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def densepose_chart_predictor_output_hflip(
|
| 11 |
+
densepose_predictor_output: DensePoseChartPredictorOutput,
|
| 12 |
+
transform_data: DensePoseTransformData,
|
| 13 |
+
) -> DensePoseChartPredictorOutput:
|
| 14 |
+
"""
|
| 15 |
+
Change to take into account a Horizontal flip.
|
| 16 |
+
"""
|
| 17 |
+
if len(densepose_predictor_output) > 0:
|
| 18 |
+
|
| 19 |
+
PredictorOutput = type(densepose_predictor_output)
|
| 20 |
+
output_dict = {}
|
| 21 |
+
|
| 22 |
+
for field in fields(densepose_predictor_output):
|
| 23 |
+
field_value = getattr(densepose_predictor_output, field.name)
|
| 24 |
+
# flip tensors
|
| 25 |
+
if isinstance(field_value, torch.Tensor):
|
| 26 |
+
setattr(densepose_predictor_output, field.name, torch.flip(field_value, [3]))
|
| 27 |
+
|
| 28 |
+
densepose_predictor_output = _flip_iuv_semantics_tensor(
|
| 29 |
+
densepose_predictor_output, transform_data
|
| 30 |
+
)
|
| 31 |
+
densepose_predictor_output = _flip_segm_semantics_tensor(
|
| 32 |
+
densepose_predictor_output, transform_data
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
for field in fields(densepose_predictor_output):
|
| 36 |
+
output_dict[field.name] = getattr(densepose_predictor_output, field.name)
|
| 37 |
+
|
| 38 |
+
return PredictorOutput(**output_dict)
|
| 39 |
+
else:
|
| 40 |
+
return densepose_predictor_output
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _flip_iuv_semantics_tensor(
|
| 44 |
+
densepose_predictor_output: DensePoseChartPredictorOutput,
|
| 45 |
+
dp_transform_data: DensePoseTransformData,
|
| 46 |
+
) -> DensePoseChartPredictorOutput:
|
| 47 |
+
point_label_symmetries = dp_transform_data.point_label_symmetries
|
| 48 |
+
uv_symmetries = dp_transform_data.uv_symmetries
|
| 49 |
+
|
| 50 |
+
N, C, H, W = densepose_predictor_output.u.shape
|
| 51 |
+
u_loc = (densepose_predictor_output.u[:, 1:, :, :].clamp(0, 1) * 255).long()
|
| 52 |
+
v_loc = (densepose_predictor_output.v[:, 1:, :, :].clamp(0, 1) * 255).long()
|
| 53 |
+
Iindex = torch.arange(C - 1, device=densepose_predictor_output.u.device)[
|
| 54 |
+
None, :, None, None
|
| 55 |
+
].expand(N, C - 1, H, W)
|
| 56 |
+
densepose_predictor_output.u[:, 1:, :, :] = uv_symmetries["U_transforms"][Iindex, v_loc, u_loc]
|
| 57 |
+
densepose_predictor_output.v[:, 1:, :, :] = uv_symmetries["V_transforms"][Iindex, v_loc, u_loc]
|
| 58 |
+
|
| 59 |
+
for el in ["fine_segm", "u", "v"]:
|
| 60 |
+
densepose_predictor_output.__dict__[el] = densepose_predictor_output.__dict__[el][
|
| 61 |
+
:, point_label_symmetries, :, :
|
| 62 |
+
]
|
| 63 |
+
return densepose_predictor_output
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _flip_segm_semantics_tensor(
|
| 67 |
+
densepose_predictor_output: DensePoseChartPredictorOutput, dp_transform_data
|
| 68 |
+
):
|
| 69 |
+
if densepose_predictor_output.coarse_segm.shape[1] > 2:
|
| 70 |
+
densepose_predictor_output.coarse_segm = densepose_predictor_output.coarse_segm[
|
| 71 |
+
:, dp_transform_data.mask_label_symmetries, :, :
|
| 72 |
+
]
|
| 73 |
+
return densepose_predictor_output
|
densepose/converters/chart_output_to_chart_result.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from typing import Dict
|
| 6 |
+
import torch
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
|
| 9 |
+
from detectron2.structures.boxes import Boxes, BoxMode
|
| 10 |
+
|
| 11 |
+
from ..structures import (
|
| 12 |
+
DensePoseChartPredictorOutput,
|
| 13 |
+
DensePoseChartResult,
|
| 14 |
+
DensePoseChartResultWithConfidences,
|
| 15 |
+
)
|
| 16 |
+
from . import resample_fine_and_coarse_segm_to_bbox
|
| 17 |
+
from .base import IntTupleBox, make_int_box
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def resample_uv_tensors_to_bbox(
|
| 21 |
+
u: torch.Tensor,
|
| 22 |
+
v: torch.Tensor,
|
| 23 |
+
labels: torch.Tensor,
|
| 24 |
+
box_xywh_abs: IntTupleBox,
|
| 25 |
+
) -> torch.Tensor:
|
| 26 |
+
"""
|
| 27 |
+
Resamples U and V coordinate estimates for the given bounding box
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
u (tensor [1, C, H, W] of float): U coordinates
|
| 31 |
+
v (tensor [1, C, H, W] of float): V coordinates
|
| 32 |
+
labels (tensor [H, W] of long): labels obtained by resampling segmentation
|
| 33 |
+
outputs for the given bounding box
|
| 34 |
+
box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
|
| 35 |
+
Return:
|
| 36 |
+
Resampled U and V coordinates - a tensor [2, H, W] of float
|
| 37 |
+
"""
|
| 38 |
+
x, y, w, h = box_xywh_abs
|
| 39 |
+
w = max(int(w), 1)
|
| 40 |
+
h = max(int(h), 1)
|
| 41 |
+
u_bbox = F.interpolate(u, (h, w), mode="bilinear", align_corners=False)
|
| 42 |
+
v_bbox = F.interpolate(v, (h, w), mode="bilinear", align_corners=False)
|
| 43 |
+
uv = torch.zeros([2, h, w], dtype=torch.float32, device=u.device)
|
| 44 |
+
for part_id in range(1, u_bbox.size(1)):
|
| 45 |
+
uv[0][labels == part_id] = u_bbox[0, part_id][labels == part_id]
|
| 46 |
+
uv[1][labels == part_id] = v_bbox[0, part_id][labels == part_id]
|
| 47 |
+
return uv
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def resample_uv_to_bbox(
|
| 51 |
+
predictor_output: DensePoseChartPredictorOutput,
|
| 52 |
+
labels: torch.Tensor,
|
| 53 |
+
box_xywh_abs: IntTupleBox,
|
| 54 |
+
) -> torch.Tensor:
|
| 55 |
+
"""
|
| 56 |
+
Resamples U and V coordinate estimates for the given bounding box
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
predictor_output (DensePoseChartPredictorOutput): DensePose predictor
|
| 60 |
+
output to be resampled
|
| 61 |
+
labels (tensor [H, W] of long): labels obtained by resampling segmentation
|
| 62 |
+
outputs for the given bounding box
|
| 63 |
+
box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
|
| 64 |
+
Return:
|
| 65 |
+
Resampled U and V coordinates - a tensor [2, H, W] of float
|
| 66 |
+
"""
|
| 67 |
+
return resample_uv_tensors_to_bbox(
|
| 68 |
+
predictor_output.u,
|
| 69 |
+
predictor_output.v,
|
| 70 |
+
labels,
|
| 71 |
+
box_xywh_abs,
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def densepose_chart_predictor_output_to_result(
|
| 76 |
+
predictor_output: DensePoseChartPredictorOutput, boxes: Boxes
|
| 77 |
+
) -> DensePoseChartResult:
|
| 78 |
+
"""
|
| 79 |
+
Convert densepose chart predictor outputs to results
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
predictor_output (DensePoseChartPredictorOutput): DensePose predictor
|
| 83 |
+
output to be converted to results, must contain only 1 output
|
| 84 |
+
boxes (Boxes): bounding box that corresponds to the predictor output,
|
| 85 |
+
must contain only 1 bounding box
|
| 86 |
+
Return:
|
| 87 |
+
DensePose chart-based result (DensePoseChartResult)
|
| 88 |
+
"""
|
| 89 |
+
assert len(predictor_output) == 1 and len(boxes) == 1, (
|
| 90 |
+
f"Predictor output to result conversion can operate only single outputs"
|
| 91 |
+
f", got {len(predictor_output)} predictor outputs and {len(boxes)} boxes"
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
boxes_xyxy_abs = boxes.tensor.clone()
|
| 95 |
+
boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
|
| 96 |
+
box_xywh = make_int_box(boxes_xywh_abs[0])
|
| 97 |
+
|
| 98 |
+
labels = resample_fine_and_coarse_segm_to_bbox(predictor_output, box_xywh).squeeze(0)
|
| 99 |
+
uv = resample_uv_to_bbox(predictor_output, labels, box_xywh)
|
| 100 |
+
return DensePoseChartResult(labels=labels, uv=uv)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def resample_confidences_to_bbox(
|
| 104 |
+
predictor_output: DensePoseChartPredictorOutput,
|
| 105 |
+
labels: torch.Tensor,
|
| 106 |
+
box_xywh_abs: IntTupleBox,
|
| 107 |
+
) -> Dict[str, torch.Tensor]:
|
| 108 |
+
"""
|
| 109 |
+
Resamples confidences for the given bounding box
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
predictor_output (DensePoseChartPredictorOutput): DensePose predictor
|
| 113 |
+
output to be resampled
|
| 114 |
+
labels (tensor [H, W] of long): labels obtained by resampling segmentation
|
| 115 |
+
outputs for the given bounding box
|
| 116 |
+
box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
|
| 117 |
+
Return:
|
| 118 |
+
Resampled confidences - a dict of [H, W] tensors of float
|
| 119 |
+
"""
|
| 120 |
+
|
| 121 |
+
x, y, w, h = box_xywh_abs
|
| 122 |
+
w = max(int(w), 1)
|
| 123 |
+
h = max(int(h), 1)
|
| 124 |
+
|
| 125 |
+
confidence_names = [
|
| 126 |
+
"sigma_1",
|
| 127 |
+
"sigma_2",
|
| 128 |
+
"kappa_u",
|
| 129 |
+
"kappa_v",
|
| 130 |
+
"fine_segm_confidence",
|
| 131 |
+
"coarse_segm_confidence",
|
| 132 |
+
]
|
| 133 |
+
confidence_results = {key: None for key in confidence_names}
|
| 134 |
+
confidence_names = [
|
| 135 |
+
key for key in confidence_names if getattr(predictor_output, key) is not None
|
| 136 |
+
]
|
| 137 |
+
confidence_base = torch.zeros([h, w], dtype=torch.float32, device=predictor_output.u.device)
|
| 138 |
+
|
| 139 |
+
# assign data from channels that correspond to the labels
|
| 140 |
+
for key in confidence_names:
|
| 141 |
+
resampled_confidence = F.interpolate(
|
| 142 |
+
getattr(predictor_output, key),
|
| 143 |
+
(h, w),
|
| 144 |
+
mode="bilinear",
|
| 145 |
+
align_corners=False,
|
| 146 |
+
)
|
| 147 |
+
result = confidence_base.clone()
|
| 148 |
+
for part_id in range(1, predictor_output.u.size(1)):
|
| 149 |
+
if resampled_confidence.size(1) != predictor_output.u.size(1):
|
| 150 |
+
# confidence is not part-based, don't try to fill it part by part
|
| 151 |
+
continue
|
| 152 |
+
result[labels == part_id] = resampled_confidence[0, part_id][labels == part_id]
|
| 153 |
+
|
| 154 |
+
if resampled_confidence.size(1) != predictor_output.u.size(1):
|
| 155 |
+
# confidence is not part-based, fill the data with the first channel
|
| 156 |
+
# (targeted for segmentation confidences that have only 1 channel)
|
| 157 |
+
result = resampled_confidence[0, 0]
|
| 158 |
+
|
| 159 |
+
confidence_results[key] = result
|
| 160 |
+
|
| 161 |
+
return confidence_results # pyre-ignore[7]
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def densepose_chart_predictor_output_to_result_with_confidences(
|
| 165 |
+
predictor_output: DensePoseChartPredictorOutput, boxes: Boxes
|
| 166 |
+
) -> DensePoseChartResultWithConfidences:
|
| 167 |
+
"""
|
| 168 |
+
Convert densepose chart predictor outputs to results
|
| 169 |
+
|
| 170 |
+
Args:
|
| 171 |
+
predictor_output (DensePoseChartPredictorOutput): DensePose predictor
|
| 172 |
+
output with confidences to be converted to results, must contain only 1 output
|
| 173 |
+
boxes (Boxes): bounding box that corresponds to the predictor output,
|
| 174 |
+
must contain only 1 bounding box
|
| 175 |
+
Return:
|
| 176 |
+
DensePose chart-based result with confidences (DensePoseChartResultWithConfidences)
|
| 177 |
+
"""
|
| 178 |
+
assert len(predictor_output) == 1 and len(boxes) == 1, (
|
| 179 |
+
f"Predictor output to result conversion can operate only single outputs"
|
| 180 |
+
f", got {len(predictor_output)} predictor outputs and {len(boxes)} boxes"
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
boxes_xyxy_abs = boxes.tensor.clone()
|
| 184 |
+
boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
|
| 185 |
+
box_xywh = make_int_box(boxes_xywh_abs[0])
|
| 186 |
+
|
| 187 |
+
labels = resample_fine_and_coarse_segm_to_bbox(predictor_output, box_xywh).squeeze(0)
|
| 188 |
+
uv = resample_uv_to_bbox(predictor_output, labels, box_xywh)
|
| 189 |
+
confidences = resample_confidences_to_bbox(predictor_output, labels, box_xywh)
|
| 190 |
+
return DensePoseChartResultWithConfidences(labels=labels, uv=uv, **confidences)
|
densepose/converters/hflip.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
from .base import BaseConverter
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class HFlipConverter(BaseConverter):
|
| 11 |
+
"""
|
| 12 |
+
Converts various DensePose predictor outputs to DensePose results.
|
| 13 |
+
Each DensePose predictor output type has to register its convertion strategy.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
registry = {}
|
| 17 |
+
dst_type = None
|
| 18 |
+
|
| 19 |
+
@classmethod
|
| 20 |
+
# pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
|
| 21 |
+
# inconsistently.
|
| 22 |
+
def convert(cls, predictor_outputs: Any, transform_data: Any, *args, **kwargs):
|
| 23 |
+
"""
|
| 24 |
+
Performs an horizontal flip on DensePose predictor outputs.
|
| 25 |
+
Does recursive lookup for base classes, so there's no need
|
| 26 |
+
for explicit registration for derived classes.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
predictor_outputs: DensePose predictor output to be converted to BitMasks
|
| 30 |
+
transform_data: Anything useful for the flip
|
| 31 |
+
Return:
|
| 32 |
+
An instance of the same type as predictor_outputs
|
| 33 |
+
"""
|
| 34 |
+
return super(HFlipConverter, cls).convert(
|
| 35 |
+
predictor_outputs, transform_data, *args, **kwargs
|
| 36 |
+
)
|
densepose/converters/segm_to_mask.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from typing import Any
|
| 6 |
+
import torch
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
|
| 9 |
+
from detectron2.structures import BitMasks, Boxes, BoxMode
|
| 10 |
+
|
| 11 |
+
from .base import IntTupleBox, make_int_box
|
| 12 |
+
from .to_mask import ImageSizeType
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def resample_coarse_segm_tensor_to_bbox(coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox):
|
| 16 |
+
"""
|
| 17 |
+
Resample coarse segmentation tensor to the given
|
| 18 |
+
bounding box and derive labels for each pixel of the bounding box
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
coarse_segm: float tensor of shape [1, K, Hout, Wout]
|
| 22 |
+
box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
|
| 23 |
+
corner coordinates, width (W) and height (H)
|
| 24 |
+
Return:
|
| 25 |
+
Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
|
| 26 |
+
"""
|
| 27 |
+
x, y, w, h = box_xywh_abs
|
| 28 |
+
w = max(int(w), 1)
|
| 29 |
+
h = max(int(h), 1)
|
| 30 |
+
labels = F.interpolate(coarse_segm, (h, w), mode="bilinear", align_corners=False).argmax(dim=1)
|
| 31 |
+
return labels
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def resample_fine_and_coarse_segm_tensors_to_bbox(
|
| 35 |
+
fine_segm: torch.Tensor, coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox
|
| 36 |
+
):
|
| 37 |
+
"""
|
| 38 |
+
Resample fine and coarse segmentation tensors to the given
|
| 39 |
+
bounding box and derive labels for each pixel of the bounding box
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
fine_segm: float tensor of shape [1, C, Hout, Wout]
|
| 43 |
+
coarse_segm: float tensor of shape [1, K, Hout, Wout]
|
| 44 |
+
box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
|
| 45 |
+
corner coordinates, width (W) and height (H)
|
| 46 |
+
Return:
|
| 47 |
+
Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
|
| 48 |
+
"""
|
| 49 |
+
x, y, w, h = box_xywh_abs
|
| 50 |
+
w = max(int(w), 1)
|
| 51 |
+
h = max(int(h), 1)
|
| 52 |
+
# coarse segmentation
|
| 53 |
+
coarse_segm_bbox = F.interpolate(
|
| 54 |
+
coarse_segm,
|
| 55 |
+
(h, w),
|
| 56 |
+
mode="bilinear",
|
| 57 |
+
align_corners=False,
|
| 58 |
+
).argmax(dim=1)
|
| 59 |
+
# combined coarse and fine segmentation
|
| 60 |
+
labels = (
|
| 61 |
+
F.interpolate(fine_segm, (h, w), mode="bilinear", align_corners=False).argmax(dim=1)
|
| 62 |
+
* (coarse_segm_bbox > 0).long()
|
| 63 |
+
)
|
| 64 |
+
return labels
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def resample_fine_and_coarse_segm_to_bbox(predictor_output: Any, box_xywh_abs: IntTupleBox):
|
| 68 |
+
"""
|
| 69 |
+
Resample fine and coarse segmentation outputs from a predictor to the given
|
| 70 |
+
bounding box and derive labels for each pixel of the bounding box
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
predictor_output: DensePose predictor output that contains segmentation
|
| 74 |
+
results to be resampled
|
| 75 |
+
box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
|
| 76 |
+
corner coordinates, width (W) and height (H)
|
| 77 |
+
Return:
|
| 78 |
+
Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
|
| 79 |
+
"""
|
| 80 |
+
return resample_fine_and_coarse_segm_tensors_to_bbox(
|
| 81 |
+
predictor_output.fine_segm,
|
| 82 |
+
predictor_output.coarse_segm,
|
| 83 |
+
box_xywh_abs,
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def predictor_output_with_coarse_segm_to_mask(
|
| 88 |
+
predictor_output: Any, boxes: Boxes, image_size_hw: ImageSizeType
|
| 89 |
+
) -> BitMasks:
|
| 90 |
+
"""
|
| 91 |
+
Convert predictor output with coarse and fine segmentation to a mask.
|
| 92 |
+
Assumes that predictor output has the following attributes:
|
| 93 |
+
- coarse_segm (tensor of size [N, D, H, W]): coarse segmentation
|
| 94 |
+
unnormalized scores for N instances; D is the number of coarse
|
| 95 |
+
segmentation labels, H and W is the resolution of the estimate
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
predictor_output: DensePose predictor output to be converted to mask
|
| 99 |
+
boxes (Boxes): bounding boxes that correspond to the DensePose
|
| 100 |
+
predictor outputs
|
| 101 |
+
image_size_hw (tuple [int, int]): image height Himg and width Wimg
|
| 102 |
+
Return:
|
| 103 |
+
BitMasks that contain a bool tensor of size [N, Himg, Wimg] with
|
| 104 |
+
a mask of the size of the image for each instance
|
| 105 |
+
"""
|
| 106 |
+
H, W = image_size_hw
|
| 107 |
+
boxes_xyxy_abs = boxes.tensor.clone()
|
| 108 |
+
boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
|
| 109 |
+
N = len(boxes_xywh_abs)
|
| 110 |
+
masks = torch.zeros((N, H, W), dtype=torch.bool, device=boxes.tensor.device)
|
| 111 |
+
for i in range(len(boxes_xywh_abs)):
|
| 112 |
+
box_xywh = make_int_box(boxes_xywh_abs[i])
|
| 113 |
+
box_mask = resample_coarse_segm_tensor_to_bbox(predictor_output[i].coarse_segm, box_xywh)
|
| 114 |
+
x, y, w, h = box_xywh
|
| 115 |
+
masks[i, y : y + h, x : x + w] = box_mask
|
| 116 |
+
|
| 117 |
+
return BitMasks(masks)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def predictor_output_with_fine_and_coarse_segm_to_mask(
|
| 121 |
+
predictor_output: Any, boxes: Boxes, image_size_hw: ImageSizeType
|
| 122 |
+
) -> BitMasks:
|
| 123 |
+
"""
|
| 124 |
+
Convert predictor output with coarse and fine segmentation to a mask.
|
| 125 |
+
Assumes that predictor output has the following attributes:
|
| 126 |
+
- coarse_segm (tensor of size [N, D, H, W]): coarse segmentation
|
| 127 |
+
unnormalized scores for N instances; D is the number of coarse
|
| 128 |
+
segmentation labels, H and W is the resolution of the estimate
|
| 129 |
+
- fine_segm (tensor of size [N, C, H, W]): fine segmentation
|
| 130 |
+
unnormalized scores for N instances; C is the number of fine
|
| 131 |
+
segmentation labels, H and W is the resolution of the estimate
|
| 132 |
+
|
| 133 |
+
Args:
|
| 134 |
+
predictor_output: DensePose predictor output to be converted to mask
|
| 135 |
+
boxes (Boxes): bounding boxes that correspond to the DensePose
|
| 136 |
+
predictor outputs
|
| 137 |
+
image_size_hw (tuple [int, int]): image height Himg and width Wimg
|
| 138 |
+
Return:
|
| 139 |
+
BitMasks that contain a bool tensor of size [N, Himg, Wimg] with
|
| 140 |
+
a mask of the size of the image for each instance
|
| 141 |
+
"""
|
| 142 |
+
H, W = image_size_hw
|
| 143 |
+
boxes_xyxy_abs = boxes.tensor.clone()
|
| 144 |
+
boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
|
| 145 |
+
N = len(boxes_xywh_abs)
|
| 146 |
+
masks = torch.zeros((N, H, W), dtype=torch.bool, device=boxes.tensor.device)
|
| 147 |
+
for i in range(len(boxes_xywh_abs)):
|
| 148 |
+
box_xywh = make_int_box(boxes_xywh_abs[i])
|
| 149 |
+
labels_i = resample_fine_and_coarse_segm_to_bbox(predictor_output[i], box_xywh)
|
| 150 |
+
x, y, w, h = box_xywh
|
| 151 |
+
masks[i, y : y + h, x : x + w] = labels_i > 0
|
| 152 |
+
return BitMasks(masks)
|
densepose/converters/to_chart_result.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
from detectron2.structures import Boxes
|
| 8 |
+
|
| 9 |
+
from ..structures import DensePoseChartResult, DensePoseChartResultWithConfidences
|
| 10 |
+
from .base import BaseConverter
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ToChartResultConverter(BaseConverter):
|
| 14 |
+
"""
|
| 15 |
+
Converts various DensePose predictor outputs to DensePose results.
|
| 16 |
+
Each DensePose predictor output type has to register its convertion strategy.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
registry = {}
|
| 20 |
+
dst_type = DensePoseChartResult
|
| 21 |
+
|
| 22 |
+
@classmethod
|
| 23 |
+
# pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
|
| 24 |
+
# inconsistently.
|
| 25 |
+
def convert(cls, predictor_outputs: Any, boxes: Boxes, *args, **kwargs) -> DensePoseChartResult:
|
| 26 |
+
"""
|
| 27 |
+
Convert DensePose predictor outputs to DensePoseResult using some registered
|
| 28 |
+
converter. Does recursive lookup for base classes, so there's no need
|
| 29 |
+
for explicit registration for derived classes.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
densepose_predictor_outputs: DensePose predictor output to be
|
| 33 |
+
converted to BitMasks
|
| 34 |
+
boxes (Boxes): bounding boxes that correspond to the DensePose
|
| 35 |
+
predictor outputs
|
| 36 |
+
Return:
|
| 37 |
+
An instance of DensePoseResult. If no suitable converter was found, raises KeyError
|
| 38 |
+
"""
|
| 39 |
+
return super(ToChartResultConverter, cls).convert(predictor_outputs, boxes, *args, **kwargs)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class ToChartResultConverterWithConfidences(BaseConverter):
|
| 43 |
+
"""
|
| 44 |
+
Converts various DensePose predictor outputs to DensePose results.
|
| 45 |
+
Each DensePose predictor output type has to register its convertion strategy.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
registry = {}
|
| 49 |
+
dst_type = DensePoseChartResultWithConfidences
|
| 50 |
+
|
| 51 |
+
@classmethod
|
| 52 |
+
# pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
|
| 53 |
+
# inconsistently.
|
| 54 |
+
def convert(
|
| 55 |
+
cls, predictor_outputs: Any, boxes: Boxes, *args, **kwargs
|
| 56 |
+
) -> DensePoseChartResultWithConfidences:
|
| 57 |
+
"""
|
| 58 |
+
Convert DensePose predictor outputs to DensePoseResult with confidences
|
| 59 |
+
using some registered converter. Does recursive lookup for base classes,
|
| 60 |
+
so there's no need for explicit registration for derived classes.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
densepose_predictor_outputs: DensePose predictor output with confidences
|
| 64 |
+
to be converted to BitMasks
|
| 65 |
+
boxes (Boxes): bounding boxes that correspond to the DensePose
|
| 66 |
+
predictor outputs
|
| 67 |
+
Return:
|
| 68 |
+
An instance of DensePoseResult. If no suitable converter was found, raises KeyError
|
| 69 |
+
"""
|
| 70 |
+
return super(ToChartResultConverterWithConfidences, cls).convert(
|
| 71 |
+
predictor_outputs, boxes, *args, **kwargs
|
| 72 |
+
)
|
densepose/converters/to_mask.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from typing import Any, Tuple
|
| 6 |
+
|
| 7 |
+
from detectron2.structures import BitMasks, Boxes
|
| 8 |
+
|
| 9 |
+
from .base import BaseConverter
|
| 10 |
+
|
| 11 |
+
ImageSizeType = Tuple[int, int]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ToMaskConverter(BaseConverter):
|
| 15 |
+
"""
|
| 16 |
+
Converts various DensePose predictor outputs to masks
|
| 17 |
+
in bit mask format (see `BitMasks`). Each DensePose predictor output type
|
| 18 |
+
has to register its convertion strategy.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
registry = {}
|
| 22 |
+
dst_type = BitMasks
|
| 23 |
+
|
| 24 |
+
@classmethod
|
| 25 |
+
# pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
|
| 26 |
+
# inconsistently.
|
| 27 |
+
def convert(
|
| 28 |
+
cls,
|
| 29 |
+
densepose_predictor_outputs: Any,
|
| 30 |
+
boxes: Boxes,
|
| 31 |
+
image_size_hw: ImageSizeType,
|
| 32 |
+
*args,
|
| 33 |
+
**kwargs
|
| 34 |
+
) -> BitMasks:
|
| 35 |
+
"""
|
| 36 |
+
Convert DensePose predictor outputs to BitMasks using some registered
|
| 37 |
+
converter. Does recursive lookup for base classes, so there's no need
|
| 38 |
+
for explicit registration for derived classes.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
densepose_predictor_outputs: DensePose predictor output to be
|
| 42 |
+
converted to BitMasks
|
| 43 |
+
boxes (Boxes): bounding boxes that correspond to the DensePose
|
| 44 |
+
predictor outputs
|
| 45 |
+
image_size_hw (tuple [int, int]): image height and width
|
| 46 |
+
Return:
|
| 47 |
+
An instance of `BitMasks`. If no suitable converter was found, raises KeyError
|
| 48 |
+
"""
|
| 49 |
+
return super(ToMaskConverter, cls).convert(
|
| 50 |
+
densepose_predictor_outputs, boxes, image_size_hw, *args, **kwargs
|
| 51 |
+
)
|
densepose/data/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from .meshes import builtin
|
| 6 |
+
from .build import (
|
| 7 |
+
build_detection_test_loader,
|
| 8 |
+
build_detection_train_loader,
|
| 9 |
+
build_combined_loader,
|
| 10 |
+
build_frame_selector,
|
| 11 |
+
build_inference_based_loaders,
|
| 12 |
+
has_inference_based_loaders,
|
| 13 |
+
BootstrapDatasetFactoryCatalog,
|
| 14 |
+
)
|
| 15 |
+
from .combined_loader import CombinedDataLoader
|
| 16 |
+
from .dataset_mapper import DatasetMapper
|
| 17 |
+
from .inference_based_loader import InferenceBasedLoader, ScoreBasedFilter
|
| 18 |
+
from .image_list_dataset import ImageListDataset
|
| 19 |
+
from .utils import is_relative_local_path, maybe_prepend_base_path
|
| 20 |
+
|
| 21 |
+
# ensure the builtin datasets are registered
|
| 22 |
+
from . import datasets
|
| 23 |
+
|
| 24 |
+
# ensure the bootstrap datasets builders are registered
|
| 25 |
+
from . import build
|
| 26 |
+
|
| 27 |
+
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
densepose/data/build.py
ADDED
|
@@ -0,0 +1,738 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
import itertools
|
| 6 |
+
import logging
|
| 7 |
+
import numpy as np
|
| 8 |
+
from collections import UserDict, defaultdict
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from typing import Any, Callable, Collection, Dict, Iterable, List, Optional, Sequence, Tuple
|
| 11 |
+
import torch
|
| 12 |
+
from torch.utils.data.dataset import Dataset
|
| 13 |
+
|
| 14 |
+
from detectron2.config import CfgNode
|
| 15 |
+
from detectron2.data.build import build_detection_test_loader as d2_build_detection_test_loader
|
| 16 |
+
from detectron2.data.build import build_detection_train_loader as d2_build_detection_train_loader
|
| 17 |
+
from detectron2.data.build import (
|
| 18 |
+
load_proposals_into_dataset,
|
| 19 |
+
print_instances_class_histogram,
|
| 20 |
+
trivial_batch_collator,
|
| 21 |
+
worker_init_reset_seed,
|
| 22 |
+
)
|
| 23 |
+
from detectron2.data.catalog import DatasetCatalog, Metadata, MetadataCatalog
|
| 24 |
+
from detectron2.data.samplers import TrainingSampler
|
| 25 |
+
from detectron2.utils.comm import get_world_size
|
| 26 |
+
|
| 27 |
+
from densepose.config import get_bootstrap_dataset_config
|
| 28 |
+
from densepose.modeling import build_densepose_embedder
|
| 29 |
+
|
| 30 |
+
from .combined_loader import CombinedDataLoader, Loader
|
| 31 |
+
from .dataset_mapper import DatasetMapper
|
| 32 |
+
from .datasets.coco import DENSEPOSE_CSE_KEYS_WITHOUT_MASK, DENSEPOSE_IUV_KEYS_WITHOUT_MASK
|
| 33 |
+
from .datasets.dataset_type import DatasetType
|
| 34 |
+
from .inference_based_loader import InferenceBasedLoader, ScoreBasedFilter
|
| 35 |
+
from .samplers import (
|
| 36 |
+
DensePoseConfidenceBasedSampler,
|
| 37 |
+
DensePoseCSEConfidenceBasedSampler,
|
| 38 |
+
DensePoseCSEUniformSampler,
|
| 39 |
+
DensePoseUniformSampler,
|
| 40 |
+
MaskFromDensePoseSampler,
|
| 41 |
+
PredictionToGroundTruthSampler,
|
| 42 |
+
)
|
| 43 |
+
from .transform import ImageResizeTransform
|
| 44 |
+
from .utils import get_category_to_class_mapping, get_class_to_mesh_name_mapping
|
| 45 |
+
from .video import (
|
| 46 |
+
FirstKFramesSelector,
|
| 47 |
+
FrameSelectionStrategy,
|
| 48 |
+
LastKFramesSelector,
|
| 49 |
+
RandomKFramesSelector,
|
| 50 |
+
VideoKeyframeDataset,
|
| 51 |
+
video_list_from_file,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
__all__ = ["build_detection_train_loader", "build_detection_test_loader"]
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
Instance = Dict[str, Any]
|
| 58 |
+
InstancePredicate = Callable[[Instance], bool]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _compute_num_images_per_worker(cfg: CfgNode) -> int:
|
| 62 |
+
num_workers = get_world_size()
|
| 63 |
+
images_per_batch = cfg.SOLVER.IMS_PER_BATCH
|
| 64 |
+
assert (
|
| 65 |
+
images_per_batch % num_workers == 0
|
| 66 |
+
), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of workers ({}).".format(
|
| 67 |
+
images_per_batch, num_workers
|
| 68 |
+
)
|
| 69 |
+
assert (
|
| 70 |
+
images_per_batch >= num_workers
|
| 71 |
+
), "SOLVER.IMS_PER_BATCH ({}) must be larger than the number of workers ({}).".format(
|
| 72 |
+
images_per_batch, num_workers
|
| 73 |
+
)
|
| 74 |
+
images_per_worker = images_per_batch // num_workers
|
| 75 |
+
return images_per_worker
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _map_category_id_to_contiguous_id(dataset_name: str, dataset_dicts: Iterable[Instance]) -> None:
|
| 79 |
+
meta = MetadataCatalog.get(dataset_name)
|
| 80 |
+
for dataset_dict in dataset_dicts:
|
| 81 |
+
for ann in dataset_dict["annotations"]:
|
| 82 |
+
ann["category_id"] = meta.thing_dataset_id_to_contiguous_id[ann["category_id"]]
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@dataclass
|
| 86 |
+
class _DatasetCategory:
|
| 87 |
+
"""
|
| 88 |
+
Class representing category data in a dataset:
|
| 89 |
+
- id: category ID, as specified in the dataset annotations file
|
| 90 |
+
- name: category name, as specified in the dataset annotations file
|
| 91 |
+
- mapped_id: category ID after applying category maps (DATASETS.CATEGORY_MAPS config option)
|
| 92 |
+
- mapped_name: category name after applying category maps
|
| 93 |
+
- dataset_name: dataset in which the category is defined
|
| 94 |
+
|
| 95 |
+
For example, when training models in a class-agnostic manner, one could take LVIS 1.0
|
| 96 |
+
dataset and map the animal categories to the same category as human data from COCO:
|
| 97 |
+
id = 225
|
| 98 |
+
name = "cat"
|
| 99 |
+
mapped_id = 1
|
| 100 |
+
mapped_name = "person"
|
| 101 |
+
dataset_name = "lvis_v1_animals_dp_train"
|
| 102 |
+
"""
|
| 103 |
+
|
| 104 |
+
id: int
|
| 105 |
+
name: str
|
| 106 |
+
mapped_id: int
|
| 107 |
+
mapped_name: str
|
| 108 |
+
dataset_name: str
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
_MergedCategoriesT = Dict[int, List[_DatasetCategory]]
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def _add_category_id_to_contiguous_id_maps_to_metadata(
|
| 115 |
+
merged_categories: _MergedCategoriesT,
|
| 116 |
+
) -> None:
|
| 117 |
+
merged_categories_per_dataset = {}
|
| 118 |
+
for contiguous_cat_id, cat_id in enumerate(sorted(merged_categories.keys())):
|
| 119 |
+
for cat in merged_categories[cat_id]:
|
| 120 |
+
if cat.dataset_name not in merged_categories_per_dataset:
|
| 121 |
+
merged_categories_per_dataset[cat.dataset_name] = defaultdict(list)
|
| 122 |
+
merged_categories_per_dataset[cat.dataset_name][cat_id].append(
|
| 123 |
+
(
|
| 124 |
+
contiguous_cat_id,
|
| 125 |
+
cat,
|
| 126 |
+
)
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
logger = logging.getLogger(__name__)
|
| 130 |
+
for dataset_name, merged_categories in merged_categories_per_dataset.items():
|
| 131 |
+
meta = MetadataCatalog.get(dataset_name)
|
| 132 |
+
if not hasattr(meta, "thing_classes"):
|
| 133 |
+
meta.thing_classes = []
|
| 134 |
+
meta.thing_dataset_id_to_contiguous_id = {}
|
| 135 |
+
meta.thing_dataset_id_to_merged_id = {}
|
| 136 |
+
else:
|
| 137 |
+
meta.thing_classes.clear()
|
| 138 |
+
meta.thing_dataset_id_to_contiguous_id.clear()
|
| 139 |
+
meta.thing_dataset_id_to_merged_id.clear()
|
| 140 |
+
logger.info(f"Dataset {dataset_name}: category ID to contiguous ID mapping:")
|
| 141 |
+
for _cat_id, categories in sorted(merged_categories.items()):
|
| 142 |
+
added_to_thing_classes = False
|
| 143 |
+
for contiguous_cat_id, cat in categories:
|
| 144 |
+
if not added_to_thing_classes:
|
| 145 |
+
meta.thing_classes.append(cat.mapped_name)
|
| 146 |
+
added_to_thing_classes = True
|
| 147 |
+
meta.thing_dataset_id_to_contiguous_id[cat.id] = contiguous_cat_id
|
| 148 |
+
meta.thing_dataset_id_to_merged_id[cat.id] = cat.mapped_id
|
| 149 |
+
logger.info(f"{cat.id} ({cat.name}) -> {contiguous_cat_id}")
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def _maybe_create_general_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
|
| 153 |
+
def has_annotations(instance: Instance) -> bool:
|
| 154 |
+
return "annotations" in instance
|
| 155 |
+
|
| 156 |
+
def has_only_crowd_anotations(instance: Instance) -> bool:
|
| 157 |
+
for ann in instance["annotations"]:
|
| 158 |
+
if ann.get("is_crowd", 0) == 0:
|
| 159 |
+
return False
|
| 160 |
+
return True
|
| 161 |
+
|
| 162 |
+
def general_keep_instance_predicate(instance: Instance) -> bool:
|
| 163 |
+
return has_annotations(instance) and not has_only_crowd_anotations(instance)
|
| 164 |
+
|
| 165 |
+
if not cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS:
|
| 166 |
+
return None
|
| 167 |
+
return general_keep_instance_predicate
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def _maybe_create_keypoints_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
|
| 171 |
+
|
| 172 |
+
min_num_keypoints = cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
|
| 173 |
+
|
| 174 |
+
def has_sufficient_num_keypoints(instance: Instance) -> bool:
|
| 175 |
+
num_kpts = sum(
|
| 176 |
+
(np.array(ann["keypoints"][2::3]) > 0).sum()
|
| 177 |
+
for ann in instance["annotations"]
|
| 178 |
+
if "keypoints" in ann
|
| 179 |
+
)
|
| 180 |
+
return num_kpts >= min_num_keypoints
|
| 181 |
+
|
| 182 |
+
if cfg.MODEL.KEYPOINT_ON and (min_num_keypoints > 0):
|
| 183 |
+
return has_sufficient_num_keypoints
|
| 184 |
+
return None
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def _maybe_create_mask_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
|
| 188 |
+
if not cfg.MODEL.MASK_ON:
|
| 189 |
+
return None
|
| 190 |
+
|
| 191 |
+
def has_mask_annotations(instance: Instance) -> bool:
|
| 192 |
+
return any("segmentation" in ann for ann in instance["annotations"])
|
| 193 |
+
|
| 194 |
+
return has_mask_annotations
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def _maybe_create_densepose_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
|
| 198 |
+
if not cfg.MODEL.DENSEPOSE_ON:
|
| 199 |
+
return None
|
| 200 |
+
|
| 201 |
+
use_masks = cfg.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS
|
| 202 |
+
|
| 203 |
+
def has_densepose_annotations(instance: Instance) -> bool:
|
| 204 |
+
for ann in instance["annotations"]:
|
| 205 |
+
if all(key in ann for key in DENSEPOSE_IUV_KEYS_WITHOUT_MASK) or all(
|
| 206 |
+
key in ann for key in DENSEPOSE_CSE_KEYS_WITHOUT_MASK
|
| 207 |
+
):
|
| 208 |
+
return True
|
| 209 |
+
if use_masks and "segmentation" in ann:
|
| 210 |
+
return True
|
| 211 |
+
return False
|
| 212 |
+
|
| 213 |
+
return has_densepose_annotations
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def _maybe_create_specific_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
|
| 217 |
+
specific_predicate_creators = [
|
| 218 |
+
_maybe_create_keypoints_keep_instance_predicate,
|
| 219 |
+
_maybe_create_mask_keep_instance_predicate,
|
| 220 |
+
_maybe_create_densepose_keep_instance_predicate,
|
| 221 |
+
]
|
| 222 |
+
predicates = [creator(cfg) for creator in specific_predicate_creators]
|
| 223 |
+
predicates = [p for p in predicates if p is not None]
|
| 224 |
+
if not predicates:
|
| 225 |
+
return None
|
| 226 |
+
|
| 227 |
+
def combined_predicate(instance: Instance) -> bool:
|
| 228 |
+
return any(p(instance) for p in predicates)
|
| 229 |
+
|
| 230 |
+
return combined_predicate
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def _get_train_keep_instance_predicate(cfg: CfgNode):
|
| 234 |
+
general_keep_predicate = _maybe_create_general_keep_instance_predicate(cfg)
|
| 235 |
+
combined_specific_keep_predicate = _maybe_create_specific_keep_instance_predicate(cfg)
|
| 236 |
+
|
| 237 |
+
def combined_general_specific_keep_predicate(instance: Instance) -> bool:
|
| 238 |
+
return general_keep_predicate(instance) and combined_specific_keep_predicate(instance)
|
| 239 |
+
|
| 240 |
+
if (general_keep_predicate is None) and (combined_specific_keep_predicate is None):
|
| 241 |
+
return None
|
| 242 |
+
if general_keep_predicate is None:
|
| 243 |
+
return combined_specific_keep_predicate
|
| 244 |
+
if combined_specific_keep_predicate is None:
|
| 245 |
+
return general_keep_predicate
|
| 246 |
+
return combined_general_specific_keep_predicate
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def _get_test_keep_instance_predicate(cfg: CfgNode):
|
| 250 |
+
general_keep_predicate = _maybe_create_general_keep_instance_predicate(cfg)
|
| 251 |
+
return general_keep_predicate
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def _maybe_filter_and_map_categories(
|
| 255 |
+
dataset_name: str, dataset_dicts: List[Instance]
|
| 256 |
+
) -> List[Instance]:
|
| 257 |
+
meta = MetadataCatalog.get(dataset_name)
|
| 258 |
+
category_id_map = meta.thing_dataset_id_to_contiguous_id
|
| 259 |
+
filtered_dataset_dicts = []
|
| 260 |
+
for dataset_dict in dataset_dicts:
|
| 261 |
+
anns = []
|
| 262 |
+
for ann in dataset_dict["annotations"]:
|
| 263 |
+
cat_id = ann["category_id"]
|
| 264 |
+
if cat_id not in category_id_map:
|
| 265 |
+
continue
|
| 266 |
+
ann["category_id"] = category_id_map[cat_id]
|
| 267 |
+
anns.append(ann)
|
| 268 |
+
dataset_dict["annotations"] = anns
|
| 269 |
+
filtered_dataset_dicts.append(dataset_dict)
|
| 270 |
+
return filtered_dataset_dicts
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def _add_category_whitelists_to_metadata(cfg: CfgNode) -> None:
|
| 274 |
+
for dataset_name, whitelisted_cat_ids in cfg.DATASETS.WHITELISTED_CATEGORIES.items():
|
| 275 |
+
meta = MetadataCatalog.get(dataset_name)
|
| 276 |
+
meta.whitelisted_categories = whitelisted_cat_ids
|
| 277 |
+
logger = logging.getLogger(__name__)
|
| 278 |
+
logger.info(
|
| 279 |
+
"Whitelisted categories for dataset {}: {}".format(
|
| 280 |
+
dataset_name, meta.whitelisted_categories
|
| 281 |
+
)
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def _add_category_maps_to_metadata(cfg: CfgNode) -> None:
|
| 286 |
+
for dataset_name, category_map in cfg.DATASETS.CATEGORY_MAPS.items():
|
| 287 |
+
category_map = {
|
| 288 |
+
int(cat_id_src): int(cat_id_dst) for cat_id_src, cat_id_dst in category_map.items()
|
| 289 |
+
}
|
| 290 |
+
meta = MetadataCatalog.get(dataset_name)
|
| 291 |
+
meta.category_map = category_map
|
| 292 |
+
logger = logging.getLogger(__name__)
|
| 293 |
+
logger.info("Category maps for dataset {}: {}".format(dataset_name, meta.category_map))
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def _add_category_info_to_bootstrapping_metadata(dataset_name: str, dataset_cfg: CfgNode) -> None:
|
| 297 |
+
meta = MetadataCatalog.get(dataset_name)
|
| 298 |
+
meta.category_to_class_mapping = get_category_to_class_mapping(dataset_cfg)
|
| 299 |
+
meta.categories = dataset_cfg.CATEGORIES
|
| 300 |
+
meta.max_count_per_category = dataset_cfg.MAX_COUNT_PER_CATEGORY
|
| 301 |
+
logger = logging.getLogger(__name__)
|
| 302 |
+
logger.info(
|
| 303 |
+
"Category to class mapping for dataset {}: {}".format(
|
| 304 |
+
dataset_name, meta.category_to_class_mapping
|
| 305 |
+
)
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def _maybe_add_class_to_mesh_name_map_to_metadata(dataset_names: List[str], cfg: CfgNode) -> None:
|
| 310 |
+
for dataset_name in dataset_names:
|
| 311 |
+
meta = MetadataCatalog.get(dataset_name)
|
| 312 |
+
if not hasattr(meta, "class_to_mesh_name"):
|
| 313 |
+
meta.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def _merge_categories(dataset_names: Collection[str]) -> _MergedCategoriesT:
|
| 317 |
+
merged_categories = defaultdict(list)
|
| 318 |
+
category_names = {}
|
| 319 |
+
for dataset_name in dataset_names:
|
| 320 |
+
meta = MetadataCatalog.get(dataset_name)
|
| 321 |
+
whitelisted_categories = meta.get("whitelisted_categories")
|
| 322 |
+
category_map = meta.get("category_map", {})
|
| 323 |
+
cat_ids = (
|
| 324 |
+
whitelisted_categories if whitelisted_categories is not None else meta.categories.keys()
|
| 325 |
+
)
|
| 326 |
+
for cat_id in cat_ids:
|
| 327 |
+
cat_name = meta.categories[cat_id]
|
| 328 |
+
cat_id_mapped = category_map.get(cat_id, cat_id)
|
| 329 |
+
if cat_id_mapped == cat_id or cat_id_mapped in cat_ids:
|
| 330 |
+
category_names[cat_id] = cat_name
|
| 331 |
+
else:
|
| 332 |
+
category_names[cat_id] = str(cat_id_mapped)
|
| 333 |
+
# assign temporary mapped category name, this name can be changed
|
| 334 |
+
# during the second pass, since mapped ID can correspond to a category
|
| 335 |
+
# from a different dataset
|
| 336 |
+
cat_name_mapped = meta.categories[cat_id_mapped]
|
| 337 |
+
merged_categories[cat_id_mapped].append(
|
| 338 |
+
_DatasetCategory(
|
| 339 |
+
id=cat_id,
|
| 340 |
+
name=cat_name,
|
| 341 |
+
mapped_id=cat_id_mapped,
|
| 342 |
+
mapped_name=cat_name_mapped,
|
| 343 |
+
dataset_name=dataset_name,
|
| 344 |
+
)
|
| 345 |
+
)
|
| 346 |
+
# second pass to assign proper mapped category names
|
| 347 |
+
for cat_id, categories in merged_categories.items():
|
| 348 |
+
for cat in categories:
|
| 349 |
+
if cat_id in category_names and cat.mapped_name != category_names[cat_id]:
|
| 350 |
+
cat.mapped_name = category_names[cat_id]
|
| 351 |
+
|
| 352 |
+
return merged_categories
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def _warn_if_merged_different_categories(merged_categories: _MergedCategoriesT) -> None:
|
| 356 |
+
logger = logging.getLogger(__name__)
|
| 357 |
+
for cat_id in merged_categories:
|
| 358 |
+
merged_categories_i = merged_categories[cat_id]
|
| 359 |
+
first_cat_name = merged_categories_i[0].name
|
| 360 |
+
if len(merged_categories_i) > 1 and not all(
|
| 361 |
+
cat.name == first_cat_name for cat in merged_categories_i[1:]
|
| 362 |
+
):
|
| 363 |
+
cat_summary_str = ", ".join(
|
| 364 |
+
[f"{cat.id} ({cat.name}) from {cat.dataset_name}" for cat in merged_categories_i]
|
| 365 |
+
)
|
| 366 |
+
logger.warning(
|
| 367 |
+
f"Merged category {cat_id} corresponds to the following categories: "
|
| 368 |
+
f"{cat_summary_str}"
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def combine_detection_dataset_dicts(
|
| 373 |
+
dataset_names: Collection[str],
|
| 374 |
+
keep_instance_predicate: Optional[InstancePredicate] = None,
|
| 375 |
+
proposal_files: Optional[Collection[str]] = None,
|
| 376 |
+
) -> List[Instance]:
|
| 377 |
+
"""
|
| 378 |
+
Load and prepare dataset dicts for training / testing
|
| 379 |
+
|
| 380 |
+
Args:
|
| 381 |
+
dataset_names (Collection[str]): a list of dataset names
|
| 382 |
+
keep_instance_predicate (Callable: Dict[str, Any] -> bool): predicate
|
| 383 |
+
applied to instance dicts which defines whether to keep the instance
|
| 384 |
+
proposal_files (Collection[str]): if given, a list of object proposal files
|
| 385 |
+
that match each dataset in `dataset_names`.
|
| 386 |
+
"""
|
| 387 |
+
assert len(dataset_names)
|
| 388 |
+
if proposal_files is None:
|
| 389 |
+
proposal_files = [None] * len(dataset_names)
|
| 390 |
+
assert len(dataset_names) == len(proposal_files)
|
| 391 |
+
# load datasets and metadata
|
| 392 |
+
dataset_name_to_dicts = {}
|
| 393 |
+
for dataset_name in dataset_names:
|
| 394 |
+
dataset_name_to_dicts[dataset_name] = DatasetCatalog.get(dataset_name)
|
| 395 |
+
assert len(dataset_name_to_dicts), f"Dataset '{dataset_name}' is empty!"
|
| 396 |
+
# merge categories, requires category metadata to be loaded
|
| 397 |
+
# cat_id -> [(orig_cat_id, cat_name, dataset_name)]
|
| 398 |
+
merged_categories = _merge_categories(dataset_names)
|
| 399 |
+
_warn_if_merged_different_categories(merged_categories)
|
| 400 |
+
merged_category_names = [
|
| 401 |
+
merged_categories[cat_id][0].mapped_name for cat_id in sorted(merged_categories)
|
| 402 |
+
]
|
| 403 |
+
# map to contiguous category IDs
|
| 404 |
+
_add_category_id_to_contiguous_id_maps_to_metadata(merged_categories)
|
| 405 |
+
# load annotations and dataset metadata
|
| 406 |
+
for dataset_name, proposal_file in zip(dataset_names, proposal_files):
|
| 407 |
+
dataset_dicts = dataset_name_to_dicts[dataset_name]
|
| 408 |
+
assert len(dataset_dicts), f"Dataset '{dataset_name}' is empty!"
|
| 409 |
+
if proposal_file is not None:
|
| 410 |
+
dataset_dicts = load_proposals_into_dataset(dataset_dicts, proposal_file)
|
| 411 |
+
dataset_dicts = _maybe_filter_and_map_categories(dataset_name, dataset_dicts)
|
| 412 |
+
print_instances_class_histogram(dataset_dicts, merged_category_names)
|
| 413 |
+
dataset_name_to_dicts[dataset_name] = dataset_dicts
|
| 414 |
+
|
| 415 |
+
if keep_instance_predicate is not None:
|
| 416 |
+
all_datasets_dicts_plain = [
|
| 417 |
+
d
|
| 418 |
+
for d in itertools.chain.from_iterable(dataset_name_to_dicts.values())
|
| 419 |
+
if keep_instance_predicate(d)
|
| 420 |
+
]
|
| 421 |
+
else:
|
| 422 |
+
all_datasets_dicts_plain = list(
|
| 423 |
+
itertools.chain.from_iterable(dataset_name_to_dicts.values())
|
| 424 |
+
)
|
| 425 |
+
return all_datasets_dicts_plain
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def build_detection_train_loader(cfg: CfgNode, mapper=None):
|
| 429 |
+
"""
|
| 430 |
+
A data loader is created in a way similar to that of Detectron2.
|
| 431 |
+
The main differences are:
|
| 432 |
+
- it allows to combine datasets with different but compatible object category sets
|
| 433 |
+
|
| 434 |
+
The data loader is created by the following steps:
|
| 435 |
+
1. Use the dataset names in config to query :class:`DatasetCatalog`, and obtain a list of dicts.
|
| 436 |
+
2. Start workers to work on the dicts. Each worker will:
|
| 437 |
+
* Map each metadata dict into another format to be consumed by the model.
|
| 438 |
+
* Batch them by simply putting dicts into a list.
|
| 439 |
+
The batched ``list[mapped_dict]`` is what this dataloader will return.
|
| 440 |
+
|
| 441 |
+
Args:
|
| 442 |
+
cfg (CfgNode): the config
|
| 443 |
+
mapper (callable): a callable which takes a sample (dict) from dataset and
|
| 444 |
+
returns the format to be consumed by the model.
|
| 445 |
+
By default it will be `DatasetMapper(cfg, True)`.
|
| 446 |
+
|
| 447 |
+
Returns:
|
| 448 |
+
an infinite iterator of training data
|
| 449 |
+
"""
|
| 450 |
+
|
| 451 |
+
_add_category_whitelists_to_metadata(cfg)
|
| 452 |
+
_add_category_maps_to_metadata(cfg)
|
| 453 |
+
_maybe_add_class_to_mesh_name_map_to_metadata(cfg.DATASETS.TRAIN, cfg)
|
| 454 |
+
dataset_dicts = combine_detection_dataset_dicts(
|
| 455 |
+
cfg.DATASETS.TRAIN,
|
| 456 |
+
keep_instance_predicate=_get_train_keep_instance_predicate(cfg),
|
| 457 |
+
proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
|
| 458 |
+
)
|
| 459 |
+
if mapper is None:
|
| 460 |
+
mapper = DatasetMapper(cfg, True)
|
| 461 |
+
return d2_build_detection_train_loader(cfg, dataset=dataset_dicts, mapper=mapper)
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def build_detection_test_loader(cfg, dataset_name, mapper=None):
|
| 465 |
+
"""
|
| 466 |
+
Similar to `build_detection_train_loader`.
|
| 467 |
+
But this function uses the given `dataset_name` argument (instead of the names in cfg),
|
| 468 |
+
and uses batch size 1.
|
| 469 |
+
|
| 470 |
+
Args:
|
| 471 |
+
cfg: a detectron2 CfgNode
|
| 472 |
+
dataset_name (str): a name of the dataset that's available in the DatasetCatalog
|
| 473 |
+
mapper (callable): a callable which takes a sample (dict) from dataset
|
| 474 |
+
and returns the format to be consumed by the model.
|
| 475 |
+
By default it will be `DatasetMapper(cfg, False)`.
|
| 476 |
+
|
| 477 |
+
Returns:
|
| 478 |
+
DataLoader: a torch DataLoader, that loads the given detection
|
| 479 |
+
dataset, with test-time transformation and batching.
|
| 480 |
+
"""
|
| 481 |
+
_add_category_whitelists_to_metadata(cfg)
|
| 482 |
+
_add_category_maps_to_metadata(cfg)
|
| 483 |
+
_maybe_add_class_to_mesh_name_map_to_metadata([dataset_name], cfg)
|
| 484 |
+
dataset_dicts = combine_detection_dataset_dicts(
|
| 485 |
+
[dataset_name],
|
| 486 |
+
keep_instance_predicate=_get_test_keep_instance_predicate(cfg),
|
| 487 |
+
proposal_files=(
|
| 488 |
+
[cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(dataset_name)]]
|
| 489 |
+
if cfg.MODEL.LOAD_PROPOSALS
|
| 490 |
+
else None
|
| 491 |
+
),
|
| 492 |
+
)
|
| 493 |
+
sampler = None
|
| 494 |
+
if not cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE:
|
| 495 |
+
sampler = torch.utils.data.SequentialSampler(dataset_dicts)
|
| 496 |
+
if mapper is None:
|
| 497 |
+
mapper = DatasetMapper(cfg, False)
|
| 498 |
+
return d2_build_detection_test_loader(
|
| 499 |
+
dataset_dicts, mapper=mapper, num_workers=cfg.DATALOADER.NUM_WORKERS, sampler=sampler
|
| 500 |
+
)
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
def build_frame_selector(cfg: CfgNode):
|
| 504 |
+
strategy = FrameSelectionStrategy(cfg.STRATEGY)
|
| 505 |
+
if strategy == FrameSelectionStrategy.RANDOM_K:
|
| 506 |
+
frame_selector = RandomKFramesSelector(cfg.NUM_IMAGES)
|
| 507 |
+
elif strategy == FrameSelectionStrategy.FIRST_K:
|
| 508 |
+
frame_selector = FirstKFramesSelector(cfg.NUM_IMAGES)
|
| 509 |
+
elif strategy == FrameSelectionStrategy.LAST_K:
|
| 510 |
+
frame_selector = LastKFramesSelector(cfg.NUM_IMAGES)
|
| 511 |
+
elif strategy == FrameSelectionStrategy.ALL:
|
| 512 |
+
frame_selector = None
|
| 513 |
+
# pyre-fixme[61]: `frame_selector` may not be initialized here.
|
| 514 |
+
return frame_selector
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
def build_transform(cfg: CfgNode, data_type: str):
|
| 518 |
+
if cfg.TYPE == "resize":
|
| 519 |
+
if data_type == "image":
|
| 520 |
+
return ImageResizeTransform(cfg.MIN_SIZE, cfg.MAX_SIZE)
|
| 521 |
+
raise ValueError(f"Unknown transform {cfg.TYPE} for data type {data_type}")
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
def build_combined_loader(cfg: CfgNode, loaders: Collection[Loader], ratios: Sequence[float]):
|
| 525 |
+
images_per_worker = _compute_num_images_per_worker(cfg)
|
| 526 |
+
return CombinedDataLoader(loaders, images_per_worker, ratios)
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
def build_bootstrap_dataset(dataset_name: str, cfg: CfgNode) -> Sequence[torch.Tensor]:
|
| 530 |
+
"""
|
| 531 |
+
Build dataset that provides data to bootstrap on
|
| 532 |
+
|
| 533 |
+
Args:
|
| 534 |
+
dataset_name (str): Name of the dataset, needs to have associated metadata
|
| 535 |
+
to load the data
|
| 536 |
+
cfg (CfgNode): bootstrapping config
|
| 537 |
+
Returns:
|
| 538 |
+
Sequence[Tensor] - dataset that provides image batches, Tensors of size
|
| 539 |
+
[N, C, H, W] of type float32
|
| 540 |
+
"""
|
| 541 |
+
logger = logging.getLogger(__name__)
|
| 542 |
+
_add_category_info_to_bootstrapping_metadata(dataset_name, cfg)
|
| 543 |
+
meta = MetadataCatalog.get(dataset_name)
|
| 544 |
+
factory = BootstrapDatasetFactoryCatalog.get(meta.dataset_type)
|
| 545 |
+
dataset = None
|
| 546 |
+
if factory is not None:
|
| 547 |
+
dataset = factory(meta, cfg)
|
| 548 |
+
if dataset is None:
|
| 549 |
+
logger.warning(f"Failed to create dataset {dataset_name} of type {meta.dataset_type}")
|
| 550 |
+
return dataset
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
def build_data_sampler(cfg: CfgNode, sampler_cfg: CfgNode, embedder: Optional[torch.nn.Module]):
|
| 554 |
+
if sampler_cfg.TYPE == "densepose_uniform":
|
| 555 |
+
data_sampler = PredictionToGroundTruthSampler()
|
| 556 |
+
# transform densepose pred -> gt
|
| 557 |
+
data_sampler.register_sampler(
|
| 558 |
+
"pred_densepose",
|
| 559 |
+
"gt_densepose",
|
| 560 |
+
DensePoseUniformSampler(count_per_class=sampler_cfg.COUNT_PER_CLASS),
|
| 561 |
+
)
|
| 562 |
+
data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
|
| 563 |
+
return data_sampler
|
| 564 |
+
elif sampler_cfg.TYPE == "densepose_UV_confidence":
|
| 565 |
+
data_sampler = PredictionToGroundTruthSampler()
|
| 566 |
+
# transform densepose pred -> gt
|
| 567 |
+
data_sampler.register_sampler(
|
| 568 |
+
"pred_densepose",
|
| 569 |
+
"gt_densepose",
|
| 570 |
+
DensePoseConfidenceBasedSampler(
|
| 571 |
+
confidence_channel="sigma_2",
|
| 572 |
+
count_per_class=sampler_cfg.COUNT_PER_CLASS,
|
| 573 |
+
search_proportion=0.5,
|
| 574 |
+
),
|
| 575 |
+
)
|
| 576 |
+
data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
|
| 577 |
+
return data_sampler
|
| 578 |
+
elif sampler_cfg.TYPE == "densepose_fine_segm_confidence":
|
| 579 |
+
data_sampler = PredictionToGroundTruthSampler()
|
| 580 |
+
# transform densepose pred -> gt
|
| 581 |
+
data_sampler.register_sampler(
|
| 582 |
+
"pred_densepose",
|
| 583 |
+
"gt_densepose",
|
| 584 |
+
DensePoseConfidenceBasedSampler(
|
| 585 |
+
confidence_channel="fine_segm_confidence",
|
| 586 |
+
count_per_class=sampler_cfg.COUNT_PER_CLASS,
|
| 587 |
+
search_proportion=0.5,
|
| 588 |
+
),
|
| 589 |
+
)
|
| 590 |
+
data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
|
| 591 |
+
return data_sampler
|
| 592 |
+
elif sampler_cfg.TYPE == "densepose_coarse_segm_confidence":
|
| 593 |
+
data_sampler = PredictionToGroundTruthSampler()
|
| 594 |
+
# transform densepose pred -> gt
|
| 595 |
+
data_sampler.register_sampler(
|
| 596 |
+
"pred_densepose",
|
| 597 |
+
"gt_densepose",
|
| 598 |
+
DensePoseConfidenceBasedSampler(
|
| 599 |
+
confidence_channel="coarse_segm_confidence",
|
| 600 |
+
count_per_class=sampler_cfg.COUNT_PER_CLASS,
|
| 601 |
+
search_proportion=0.5,
|
| 602 |
+
),
|
| 603 |
+
)
|
| 604 |
+
data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
|
| 605 |
+
return data_sampler
|
| 606 |
+
elif sampler_cfg.TYPE == "densepose_cse_uniform":
|
| 607 |
+
assert embedder is not None
|
| 608 |
+
data_sampler = PredictionToGroundTruthSampler()
|
| 609 |
+
# transform densepose pred -> gt
|
| 610 |
+
data_sampler.register_sampler(
|
| 611 |
+
"pred_densepose",
|
| 612 |
+
"gt_densepose",
|
| 613 |
+
DensePoseCSEUniformSampler(
|
| 614 |
+
cfg=cfg,
|
| 615 |
+
use_gt_categories=sampler_cfg.USE_GROUND_TRUTH_CATEGORIES,
|
| 616 |
+
embedder=embedder,
|
| 617 |
+
count_per_class=sampler_cfg.COUNT_PER_CLASS,
|
| 618 |
+
),
|
| 619 |
+
)
|
| 620 |
+
data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
|
| 621 |
+
return data_sampler
|
| 622 |
+
elif sampler_cfg.TYPE == "densepose_cse_coarse_segm_confidence":
|
| 623 |
+
assert embedder is not None
|
| 624 |
+
data_sampler = PredictionToGroundTruthSampler()
|
| 625 |
+
# transform densepose pred -> gt
|
| 626 |
+
data_sampler.register_sampler(
|
| 627 |
+
"pred_densepose",
|
| 628 |
+
"gt_densepose",
|
| 629 |
+
DensePoseCSEConfidenceBasedSampler(
|
| 630 |
+
cfg=cfg,
|
| 631 |
+
use_gt_categories=sampler_cfg.USE_GROUND_TRUTH_CATEGORIES,
|
| 632 |
+
embedder=embedder,
|
| 633 |
+
confidence_channel="coarse_segm_confidence",
|
| 634 |
+
count_per_class=sampler_cfg.COUNT_PER_CLASS,
|
| 635 |
+
search_proportion=0.5,
|
| 636 |
+
),
|
| 637 |
+
)
|
| 638 |
+
data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
|
| 639 |
+
return data_sampler
|
| 640 |
+
|
| 641 |
+
raise ValueError(f"Unknown data sampler type {sampler_cfg.TYPE}")
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
def build_data_filter(cfg: CfgNode):
|
| 645 |
+
if cfg.TYPE == "detection_score":
|
| 646 |
+
min_score = cfg.MIN_VALUE
|
| 647 |
+
return ScoreBasedFilter(min_score=min_score)
|
| 648 |
+
raise ValueError(f"Unknown data filter type {cfg.TYPE}")
|
| 649 |
+
|
| 650 |
+
|
| 651 |
+
def build_inference_based_loader(
|
| 652 |
+
cfg: CfgNode,
|
| 653 |
+
dataset_cfg: CfgNode,
|
| 654 |
+
model: torch.nn.Module,
|
| 655 |
+
embedder: Optional[torch.nn.Module] = None,
|
| 656 |
+
) -> InferenceBasedLoader:
|
| 657 |
+
"""
|
| 658 |
+
Constructs data loader based on inference results of a model.
|
| 659 |
+
"""
|
| 660 |
+
dataset = build_bootstrap_dataset(dataset_cfg.DATASET, dataset_cfg.IMAGE_LOADER)
|
| 661 |
+
meta = MetadataCatalog.get(dataset_cfg.DATASET)
|
| 662 |
+
training_sampler = TrainingSampler(len(dataset))
|
| 663 |
+
data_loader = torch.utils.data.DataLoader(
|
| 664 |
+
dataset, # pyre-ignore[6]
|
| 665 |
+
batch_size=dataset_cfg.IMAGE_LOADER.BATCH_SIZE,
|
| 666 |
+
sampler=training_sampler,
|
| 667 |
+
num_workers=dataset_cfg.IMAGE_LOADER.NUM_WORKERS,
|
| 668 |
+
collate_fn=trivial_batch_collator,
|
| 669 |
+
worker_init_fn=worker_init_reset_seed,
|
| 670 |
+
)
|
| 671 |
+
return InferenceBasedLoader(
|
| 672 |
+
model,
|
| 673 |
+
data_loader=data_loader,
|
| 674 |
+
data_sampler=build_data_sampler(cfg, dataset_cfg.DATA_SAMPLER, embedder),
|
| 675 |
+
data_filter=build_data_filter(dataset_cfg.FILTER),
|
| 676 |
+
shuffle=True,
|
| 677 |
+
batch_size=dataset_cfg.INFERENCE.OUTPUT_BATCH_SIZE,
|
| 678 |
+
inference_batch_size=dataset_cfg.INFERENCE.INPUT_BATCH_SIZE,
|
| 679 |
+
category_to_class_mapping=meta.category_to_class_mapping,
|
| 680 |
+
)
|
| 681 |
+
|
| 682 |
+
|
| 683 |
+
def has_inference_based_loaders(cfg: CfgNode) -> bool:
|
| 684 |
+
"""
|
| 685 |
+
Returns True, if at least one inferense-based loader must
|
| 686 |
+
be instantiated for training
|
| 687 |
+
"""
|
| 688 |
+
return len(cfg.BOOTSTRAP_DATASETS) > 0
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
def build_inference_based_loaders(
|
| 692 |
+
cfg: CfgNode, model: torch.nn.Module
|
| 693 |
+
) -> Tuple[List[InferenceBasedLoader], List[float]]:
|
| 694 |
+
loaders = []
|
| 695 |
+
ratios = []
|
| 696 |
+
embedder = build_densepose_embedder(cfg).to(device=model.device) # pyre-ignore[16]
|
| 697 |
+
for dataset_spec in cfg.BOOTSTRAP_DATASETS:
|
| 698 |
+
dataset_cfg = get_bootstrap_dataset_config().clone()
|
| 699 |
+
dataset_cfg.merge_from_other_cfg(CfgNode(dataset_spec))
|
| 700 |
+
loader = build_inference_based_loader(cfg, dataset_cfg, model, embedder)
|
| 701 |
+
loaders.append(loader)
|
| 702 |
+
ratios.append(dataset_cfg.RATIO)
|
| 703 |
+
return loaders, ratios
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
def build_video_list_dataset(meta: Metadata, cfg: CfgNode):
|
| 707 |
+
video_list_fpath = meta.video_list_fpath
|
| 708 |
+
video_base_path = meta.video_base_path
|
| 709 |
+
category = meta.category
|
| 710 |
+
if cfg.TYPE == "video_keyframe":
|
| 711 |
+
frame_selector = build_frame_selector(cfg.SELECT)
|
| 712 |
+
transform = build_transform(cfg.TRANSFORM, data_type="image")
|
| 713 |
+
video_list = video_list_from_file(video_list_fpath, video_base_path)
|
| 714 |
+
keyframe_helper_fpath = getattr(cfg, "KEYFRAME_HELPER", None)
|
| 715 |
+
return VideoKeyframeDataset(
|
| 716 |
+
video_list, category, frame_selector, transform, keyframe_helper_fpath
|
| 717 |
+
)
|
| 718 |
+
|
| 719 |
+
|
| 720 |
+
class _BootstrapDatasetFactoryCatalog(UserDict):
|
| 721 |
+
"""
|
| 722 |
+
A global dictionary that stores information about bootstrapped datasets creation functions
|
| 723 |
+
from metadata and config, for diverse DatasetType
|
| 724 |
+
"""
|
| 725 |
+
|
| 726 |
+
def register(self, dataset_type: DatasetType, factory: Callable[[Metadata, CfgNode], Dataset]):
|
| 727 |
+
"""
|
| 728 |
+
Args:
|
| 729 |
+
dataset_type (DatasetType): a DatasetType e.g. DatasetType.VIDEO_LIST
|
| 730 |
+
factory (Callable[Metadata, CfgNode]): a callable which takes Metadata and cfg
|
| 731 |
+
arguments and returns a dataset object.
|
| 732 |
+
"""
|
| 733 |
+
assert dataset_type not in self, "Dataset '{}' is already registered!".format(dataset_type)
|
| 734 |
+
self[dataset_type] = factory
|
| 735 |
+
|
| 736 |
+
|
| 737 |
+
BootstrapDatasetFactoryCatalog = _BootstrapDatasetFactoryCatalog()
|
| 738 |
+
BootstrapDatasetFactoryCatalog.register(DatasetType.VIDEO_LIST, build_video_list_dataset)
|
densepose/data/combined_loader.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
import random
|
| 6 |
+
from collections import deque
|
| 7 |
+
from typing import Any, Collection, Deque, Iterable, Iterator, List, Sequence
|
| 8 |
+
|
| 9 |
+
Loader = Iterable[Any]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _pooled_next(iterator: Iterator[Any], pool: Deque[Any]):
|
| 13 |
+
if not pool:
|
| 14 |
+
pool.extend(next(iterator))
|
| 15 |
+
return pool.popleft()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class CombinedDataLoader:
|
| 19 |
+
"""
|
| 20 |
+
Combines data loaders using the provided sampling ratios
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
BATCH_COUNT = 100
|
| 24 |
+
|
| 25 |
+
def __init__(self, loaders: Collection[Loader], batch_size: int, ratios: Sequence[float]):
|
| 26 |
+
self.loaders = loaders
|
| 27 |
+
self.batch_size = batch_size
|
| 28 |
+
self.ratios = ratios
|
| 29 |
+
|
| 30 |
+
def __iter__(self) -> Iterator[List[Any]]:
|
| 31 |
+
iters = [iter(loader) for loader in self.loaders]
|
| 32 |
+
indices = []
|
| 33 |
+
pool = [deque()] * len(iters)
|
| 34 |
+
# infinite iterator, as in D2
|
| 35 |
+
while True:
|
| 36 |
+
if not indices:
|
| 37 |
+
# just a buffer of indices, its size doesn't matter
|
| 38 |
+
# as long as it's a multiple of batch_size
|
| 39 |
+
k = self.batch_size * self.BATCH_COUNT
|
| 40 |
+
indices = random.choices(range(len(self.loaders)), self.ratios, k=k)
|
| 41 |
+
try:
|
| 42 |
+
batch = [_pooled_next(iters[i], pool[i]) for i in indices[: self.batch_size]]
|
| 43 |
+
except StopIteration:
|
| 44 |
+
break
|
| 45 |
+
indices = indices[self.batch_size :]
|
| 46 |
+
yield batch
|
densepose/data/dataset_mapper.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
|
| 4 |
+
# pyre-unsafe
|
| 5 |
+
|
| 6 |
+
import copy
|
| 7 |
+
import logging
|
| 8 |
+
from typing import Any, Dict, List, Tuple
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from detectron2.data import MetadataCatalog
|
| 12 |
+
from detectron2.data import detection_utils as utils
|
| 13 |
+
from detectron2.data import transforms as T
|
| 14 |
+
from detectron2.layers import ROIAlign
|
| 15 |
+
from detectron2.structures import BoxMode
|
| 16 |
+
from detectron2.utils.file_io import PathManager
|
| 17 |
+
|
| 18 |
+
from densepose.structures import DensePoseDataRelative, DensePoseList, DensePoseTransformData
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def build_augmentation(cfg, is_train):
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
result = utils.build_augmentation(cfg, is_train)
|
| 24 |
+
if is_train:
|
| 25 |
+
random_rotation = T.RandomRotation(
|
| 26 |
+
cfg.INPUT.ROTATION_ANGLES, expand=False, sample_style="choice"
|
| 27 |
+
)
|
| 28 |
+
result.append(random_rotation)
|
| 29 |
+
logger.info("DensePose-specific augmentation used in training: " + str(random_rotation))
|
| 30 |
+
return result
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class DatasetMapper:
|
| 34 |
+
"""
|
| 35 |
+
A customized version of `detectron2.data.DatasetMapper`
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self, cfg, is_train=True):
|
| 39 |
+
self.augmentation = build_augmentation(cfg, is_train)
|
| 40 |
+
|
| 41 |
+
# fmt: off
|
| 42 |
+
self.img_format = cfg.INPUT.FORMAT
|
| 43 |
+
self.mask_on = (
|
| 44 |
+
cfg.MODEL.MASK_ON or (
|
| 45 |
+
cfg.MODEL.DENSEPOSE_ON
|
| 46 |
+
and cfg.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS)
|
| 47 |
+
)
|
| 48 |
+
self.keypoint_on = cfg.MODEL.KEYPOINT_ON
|
| 49 |
+
self.densepose_on = cfg.MODEL.DENSEPOSE_ON
|
| 50 |
+
assert not cfg.MODEL.LOAD_PROPOSALS, "not supported yet"
|
| 51 |
+
# fmt: on
|
| 52 |
+
if self.keypoint_on and is_train:
|
| 53 |
+
# Flip only makes sense in training
|
| 54 |
+
self.keypoint_hflip_indices = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
|
| 55 |
+
else:
|
| 56 |
+
self.keypoint_hflip_indices = None
|
| 57 |
+
|
| 58 |
+
if self.densepose_on:
|
| 59 |
+
densepose_transform_srcs = [
|
| 60 |
+
MetadataCatalog.get(ds).densepose_transform_src
|
| 61 |
+
for ds in cfg.DATASETS.TRAIN + cfg.DATASETS.TEST
|
| 62 |
+
]
|
| 63 |
+
assert len(densepose_transform_srcs) > 0
|
| 64 |
+
# TODO: check that DensePose transformation data is the same for
|
| 65 |
+
# all the datasets. Otherwise one would have to pass DB ID with
|
| 66 |
+
# each entry to select proper transformation data. For now, since
|
| 67 |
+
# all DensePose annotated data uses the same data semantics, we
|
| 68 |
+
# omit this check.
|
| 69 |
+
densepose_transform_data_fpath = PathManager.get_local_path(densepose_transform_srcs[0])
|
| 70 |
+
self.densepose_transform_data = DensePoseTransformData.load(
|
| 71 |
+
densepose_transform_data_fpath
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
self.is_train = is_train
|
| 75 |
+
|
| 76 |
+
def __call__(self, dataset_dict):
|
| 77 |
+
"""
|
| 78 |
+
Args:
|
| 79 |
+
dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
dict: a format that builtin models in detectron2 accept
|
| 83 |
+
"""
|
| 84 |
+
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
|
| 85 |
+
image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
|
| 86 |
+
utils.check_image_size(dataset_dict, image)
|
| 87 |
+
|
| 88 |
+
image, transforms = T.apply_transform_gens(self.augmentation, image)
|
| 89 |
+
image_shape = image.shape[:2] # h, w
|
| 90 |
+
dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))
|
| 91 |
+
|
| 92 |
+
if not self.is_train:
|
| 93 |
+
dataset_dict.pop("annotations", None)
|
| 94 |
+
return dataset_dict
|
| 95 |
+
|
| 96 |
+
for anno in dataset_dict["annotations"]:
|
| 97 |
+
if not self.mask_on:
|
| 98 |
+
anno.pop("segmentation", None)
|
| 99 |
+
if not self.keypoint_on:
|
| 100 |
+
anno.pop("keypoints", None)
|
| 101 |
+
|
| 102 |
+
# USER: Implement additional transformations if you have other types of data
|
| 103 |
+
# USER: Don't call transpose_densepose if you don't need
|
| 104 |
+
annos = [
|
| 105 |
+
self._transform_densepose(
|
| 106 |
+
utils.transform_instance_annotations(
|
| 107 |
+
obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
|
| 108 |
+
),
|
| 109 |
+
transforms,
|
| 110 |
+
)
|
| 111 |
+
for obj in dataset_dict.pop("annotations")
|
| 112 |
+
if obj.get("iscrowd", 0) == 0
|
| 113 |
+
]
|
| 114 |
+
|
| 115 |
+
if self.mask_on:
|
| 116 |
+
self._add_densepose_masks_as_segmentation(annos, image_shape)
|
| 117 |
+
|
| 118 |
+
instances = utils.annotations_to_instances(annos, image_shape, mask_format="bitmask")
|
| 119 |
+
densepose_annotations = [obj.get("densepose") for obj in annos]
|
| 120 |
+
if densepose_annotations and not all(v is None for v in densepose_annotations):
|
| 121 |
+
instances.gt_densepose = DensePoseList(
|
| 122 |
+
densepose_annotations, instances.gt_boxes, image_shape
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
dataset_dict["instances"] = instances[instances.gt_boxes.nonempty()]
|
| 126 |
+
return dataset_dict
|
| 127 |
+
|
| 128 |
+
def _transform_densepose(self, annotation, transforms):
|
| 129 |
+
if not self.densepose_on:
|
| 130 |
+
return annotation
|
| 131 |
+
|
| 132 |
+
# Handle densepose annotations
|
| 133 |
+
is_valid, reason_not_valid = DensePoseDataRelative.validate_annotation(annotation)
|
| 134 |
+
if is_valid:
|
| 135 |
+
densepose_data = DensePoseDataRelative(annotation, cleanup=True)
|
| 136 |
+
densepose_data.apply_transform(transforms, self.densepose_transform_data)
|
| 137 |
+
annotation["densepose"] = densepose_data
|
| 138 |
+
else:
|
| 139 |
+
# logger = logging.getLogger(__name__)
|
| 140 |
+
# logger.debug("Could not load DensePose annotation: {}".format(reason_not_valid))
|
| 141 |
+
DensePoseDataRelative.cleanup_annotation(annotation)
|
| 142 |
+
# NOTE: annotations for certain instances may be unavailable.
|
| 143 |
+
# 'None' is accepted by the DensePostList data structure.
|
| 144 |
+
annotation["densepose"] = None
|
| 145 |
+
return annotation
|
| 146 |
+
|
| 147 |
+
def _add_densepose_masks_as_segmentation(
|
| 148 |
+
self, annotations: List[Dict[str, Any]], image_shape_hw: Tuple[int, int]
|
| 149 |
+
):
|
| 150 |
+
for obj in annotations:
|
| 151 |
+
if ("densepose" not in obj) or ("segmentation" in obj):
|
| 152 |
+
continue
|
| 153 |
+
# DP segmentation: torch.Tensor [S, S] of float32, S=256
|
| 154 |
+
segm_dp = torch.zeros_like(obj["densepose"].segm)
|
| 155 |
+
segm_dp[obj["densepose"].segm > 0] = 1
|
| 156 |
+
segm_h, segm_w = segm_dp.shape
|
| 157 |
+
bbox_segm_dp = torch.tensor((0, 0, segm_h - 1, segm_w - 1), dtype=torch.float32)
|
| 158 |
+
# image bbox
|
| 159 |
+
x0, y0, x1, y1 = (
|
| 160 |
+
v.item() for v in BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS)
|
| 161 |
+
)
|
| 162 |
+
segm_aligned = (
|
| 163 |
+
ROIAlign((y1 - y0, x1 - x0), 1.0, 0, aligned=True)
|
| 164 |
+
.forward(segm_dp.view(1, 1, *segm_dp.shape), bbox_segm_dp)
|
| 165 |
+
.squeeze()
|
| 166 |
+
)
|
| 167 |
+
image_mask = torch.zeros(*image_shape_hw, dtype=torch.float32)
|
| 168 |
+
image_mask[y0:y1, x0:x1] = segm_aligned
|
| 169 |
+
# segmentation for BitMask: np.array [H, W] of bool
|
| 170 |
+
obj["segmentation"] = image_mask >= 0.5
|
densepose/data/image_list_dataset.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
|
| 4 |
+
# pyre-unsafe
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import numpy as np
|
| 8 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 9 |
+
import torch
|
| 10 |
+
from torch.utils.data.dataset import Dataset
|
| 11 |
+
|
| 12 |
+
from detectron2.data.detection_utils import read_image
|
| 13 |
+
|
| 14 |
+
ImageTransform = Callable[[torch.Tensor], torch.Tensor]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ImageListDataset(Dataset):
|
| 18 |
+
"""
|
| 19 |
+
Dataset that provides images from a list.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
_EMPTY_IMAGE = torch.empty((0, 3, 1, 1))
|
| 23 |
+
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
image_list: List[str],
|
| 27 |
+
category_list: Union[str, List[str], None] = None,
|
| 28 |
+
transform: Optional[ImageTransform] = None,
|
| 29 |
+
):
|
| 30 |
+
"""
|
| 31 |
+
Args:
|
| 32 |
+
image_list (List[str]): list of paths to image files
|
| 33 |
+
category_list (Union[str, List[str], None]): list of animal categories for
|
| 34 |
+
each image. If it is a string, or None, this applies to all images
|
| 35 |
+
"""
|
| 36 |
+
if type(category_list) is list:
|
| 37 |
+
self.category_list = category_list
|
| 38 |
+
else:
|
| 39 |
+
self.category_list = [category_list] * len(image_list)
|
| 40 |
+
assert len(image_list) == len(
|
| 41 |
+
self.category_list
|
| 42 |
+
), "length of image and category lists must be equal"
|
| 43 |
+
self.image_list = image_list
|
| 44 |
+
self.transform = transform
|
| 45 |
+
|
| 46 |
+
def __getitem__(self, idx: int) -> Dict[str, Any]:
|
| 47 |
+
"""
|
| 48 |
+
Gets selected images from the list
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
idx (int): video index in the video list file
|
| 52 |
+
Returns:
|
| 53 |
+
A dictionary containing two keys:
|
| 54 |
+
images (torch.Tensor): tensor of size [N, 3, H, W] (N = 1, or 0 for _EMPTY_IMAGE)
|
| 55 |
+
categories (List[str]): categories of the frames
|
| 56 |
+
"""
|
| 57 |
+
categories = [self.category_list[idx]]
|
| 58 |
+
fpath = self.image_list[idx]
|
| 59 |
+
transform = self.transform
|
| 60 |
+
|
| 61 |
+
try:
|
| 62 |
+
image = torch.from_numpy(np.ascontiguousarray(read_image(fpath, format="BGR")))
|
| 63 |
+
image = image.permute(2, 0, 1).unsqueeze(0).float() # HWC -> NCHW
|
| 64 |
+
if transform is not None:
|
| 65 |
+
image = transform(image)
|
| 66 |
+
return {"images": image, "categories": categories}
|
| 67 |
+
except (OSError, RuntimeError) as e:
|
| 68 |
+
logger = logging.getLogger(__name__)
|
| 69 |
+
logger.warning(f"Error opening image file container {fpath}: {e}")
|
| 70 |
+
|
| 71 |
+
return {"images": self._EMPTY_IMAGE, "categories": []}
|
| 72 |
+
|
| 73 |
+
def __len__(self):
|
| 74 |
+
return len(self.image_list)
|
densepose/data/inference_based_loader.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
import random
|
| 6 |
+
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
|
| 10 |
+
SampledData = Any
|
| 11 |
+
ModelOutput = Any
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _grouper(iterable: Iterable[Any], n: int, fillvalue=None) -> Iterator[Tuple[Any]]:
|
| 15 |
+
"""
|
| 16 |
+
Group elements of an iterable by chunks of size `n`, e.g.
|
| 17 |
+
grouper(range(9), 4) ->
|
| 18 |
+
(0, 1, 2, 3), (4, 5, 6, 7), (8, None, None, None)
|
| 19 |
+
"""
|
| 20 |
+
it = iter(iterable)
|
| 21 |
+
while True:
|
| 22 |
+
values = []
|
| 23 |
+
for _ in range(n):
|
| 24 |
+
try:
|
| 25 |
+
value = next(it)
|
| 26 |
+
except StopIteration:
|
| 27 |
+
if values:
|
| 28 |
+
values.extend([fillvalue] * (n - len(values)))
|
| 29 |
+
yield tuple(values)
|
| 30 |
+
return
|
| 31 |
+
values.append(value)
|
| 32 |
+
yield tuple(values)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class ScoreBasedFilter:
|
| 36 |
+
"""
|
| 37 |
+
Filters entries in model output based on their scores
|
| 38 |
+
Discards all entries with score less than the specified minimum
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def __init__(self, min_score: float = 0.8):
|
| 42 |
+
self.min_score = min_score
|
| 43 |
+
|
| 44 |
+
def __call__(self, model_output: ModelOutput) -> ModelOutput:
|
| 45 |
+
for model_output_i in model_output:
|
| 46 |
+
instances = model_output_i["instances"]
|
| 47 |
+
if not instances.has("scores"):
|
| 48 |
+
continue
|
| 49 |
+
instances_filtered = instances[instances.scores >= self.min_score]
|
| 50 |
+
model_output_i["instances"] = instances_filtered
|
| 51 |
+
return model_output
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class InferenceBasedLoader:
|
| 55 |
+
"""
|
| 56 |
+
Data loader based on results inferred by a model. Consists of:
|
| 57 |
+
- a data loader that provides batches of images
|
| 58 |
+
- a model that is used to infer the results
|
| 59 |
+
- a data sampler that converts inferred results to annotations
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
def __init__(
|
| 63 |
+
self,
|
| 64 |
+
model: nn.Module,
|
| 65 |
+
data_loader: Iterable[List[Dict[str, Any]]],
|
| 66 |
+
data_sampler: Optional[Callable[[ModelOutput], List[SampledData]]] = None,
|
| 67 |
+
data_filter: Optional[Callable[[ModelOutput], ModelOutput]] = None,
|
| 68 |
+
shuffle: bool = True,
|
| 69 |
+
batch_size: int = 4,
|
| 70 |
+
inference_batch_size: int = 4,
|
| 71 |
+
drop_last: bool = False,
|
| 72 |
+
category_to_class_mapping: Optional[dict] = None,
|
| 73 |
+
):
|
| 74 |
+
"""
|
| 75 |
+
Constructor
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
model (torch.nn.Module): model used to produce data
|
| 79 |
+
data_loader (Iterable[List[Dict[str, Any]]]): iterable that provides
|
| 80 |
+
dictionaries with "images" and "categories" fields to perform inference on
|
| 81 |
+
data_sampler (Callable: ModelOutput -> SampledData): functor
|
| 82 |
+
that produces annotation data from inference results;
|
| 83 |
+
(optional, default: None)
|
| 84 |
+
data_filter (Callable: ModelOutput -> ModelOutput): filter
|
| 85 |
+
that selects model outputs for further processing
|
| 86 |
+
(optional, default: None)
|
| 87 |
+
shuffle (bool): if True, the input images get shuffled
|
| 88 |
+
batch_size (int): batch size for the produced annotation data
|
| 89 |
+
inference_batch_size (int): batch size for input images
|
| 90 |
+
drop_last (bool): if True, drop the last batch if it is undersized
|
| 91 |
+
category_to_class_mapping (dict): category to class mapping
|
| 92 |
+
"""
|
| 93 |
+
self.model = model
|
| 94 |
+
self.model.eval()
|
| 95 |
+
self.data_loader = data_loader
|
| 96 |
+
self.data_sampler = data_sampler
|
| 97 |
+
self.data_filter = data_filter
|
| 98 |
+
self.shuffle = shuffle
|
| 99 |
+
self.batch_size = batch_size
|
| 100 |
+
self.inference_batch_size = inference_batch_size
|
| 101 |
+
self.drop_last = drop_last
|
| 102 |
+
if category_to_class_mapping is not None:
|
| 103 |
+
self.category_to_class_mapping = category_to_class_mapping
|
| 104 |
+
else:
|
| 105 |
+
self.category_to_class_mapping = {}
|
| 106 |
+
|
| 107 |
+
def __iter__(self) -> Iterator[List[SampledData]]:
|
| 108 |
+
for batch in self.data_loader:
|
| 109 |
+
# batch : List[Dict[str: Tensor[N, C, H, W], str: Optional[str]]]
|
| 110 |
+
# images_batch : Tensor[N, C, H, W]
|
| 111 |
+
# image : Tensor[C, H, W]
|
| 112 |
+
images_and_categories = [
|
| 113 |
+
{"image": image, "category": category}
|
| 114 |
+
for element in batch
|
| 115 |
+
for image, category in zip(element["images"], element["categories"])
|
| 116 |
+
]
|
| 117 |
+
if not images_and_categories:
|
| 118 |
+
continue
|
| 119 |
+
if self.shuffle:
|
| 120 |
+
random.shuffle(images_and_categories)
|
| 121 |
+
yield from self._produce_data(images_and_categories) # pyre-ignore[6]
|
| 122 |
+
|
| 123 |
+
def _produce_data(
|
| 124 |
+
self, images_and_categories: List[Tuple[torch.Tensor, Optional[str]]]
|
| 125 |
+
) -> Iterator[List[SampledData]]:
|
| 126 |
+
"""
|
| 127 |
+
Produce batches of data from images
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
images_and_categories (List[Tuple[torch.Tensor, Optional[str]]]):
|
| 131 |
+
list of images and corresponding categories to process
|
| 132 |
+
|
| 133 |
+
Returns:
|
| 134 |
+
Iterator over batches of data sampled from model outputs
|
| 135 |
+
"""
|
| 136 |
+
data_batches: List[SampledData] = []
|
| 137 |
+
category_to_class_mapping = self.category_to_class_mapping
|
| 138 |
+
batched_images_and_categories = _grouper(images_and_categories, self.inference_batch_size)
|
| 139 |
+
for batch in batched_images_and_categories:
|
| 140 |
+
batch = [
|
| 141 |
+
{
|
| 142 |
+
"image": image_and_category["image"].to(self.model.device),
|
| 143 |
+
"category": image_and_category["category"],
|
| 144 |
+
}
|
| 145 |
+
for image_and_category in batch
|
| 146 |
+
if image_and_category is not None
|
| 147 |
+
]
|
| 148 |
+
if not batch:
|
| 149 |
+
continue
|
| 150 |
+
with torch.no_grad():
|
| 151 |
+
model_output = self.model(batch)
|
| 152 |
+
for model_output_i, batch_i in zip(model_output, batch):
|
| 153 |
+
assert len(batch_i["image"].shape) == 3
|
| 154 |
+
model_output_i["image"] = batch_i["image"]
|
| 155 |
+
instance_class = category_to_class_mapping.get(batch_i["category"], 0)
|
| 156 |
+
model_output_i["instances"].dataset_classes = torch.tensor(
|
| 157 |
+
[instance_class] * len(model_output_i["instances"])
|
| 158 |
+
)
|
| 159 |
+
model_output_filtered = (
|
| 160 |
+
model_output if self.data_filter is None else self.data_filter(model_output)
|
| 161 |
+
)
|
| 162 |
+
data = (
|
| 163 |
+
model_output_filtered
|
| 164 |
+
if self.data_sampler is None
|
| 165 |
+
else self.data_sampler(model_output_filtered)
|
| 166 |
+
)
|
| 167 |
+
for data_i in data:
|
| 168 |
+
if len(data_i["instances"]):
|
| 169 |
+
data_batches.append(data_i)
|
| 170 |
+
if len(data_batches) >= self.batch_size:
|
| 171 |
+
yield data_batches[: self.batch_size]
|
| 172 |
+
data_batches = data_batches[self.batch_size :]
|
| 173 |
+
if not self.drop_last and data_batches:
|
| 174 |
+
yield data_batches
|
densepose/data/meshes/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from . import builtin
|
| 6 |
+
|
| 7 |
+
__all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]
|
densepose/data/meshes/builtin.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from .catalog import MeshInfo, register_meshes
|
| 6 |
+
|
| 7 |
+
DENSEPOSE_MESHES_DIR = "https://dl.fbaipublicfiles.com/densepose/meshes/"
|
| 8 |
+
|
| 9 |
+
MESHES = [
|
| 10 |
+
MeshInfo(
|
| 11 |
+
name="smpl_27554",
|
| 12 |
+
data="smpl_27554.pkl",
|
| 13 |
+
geodists="geodists/geodists_smpl_27554.pkl",
|
| 14 |
+
symmetry="symmetry/symmetry_smpl_27554.pkl",
|
| 15 |
+
texcoords="texcoords/texcoords_smpl_27554.pkl",
|
| 16 |
+
),
|
| 17 |
+
MeshInfo(
|
| 18 |
+
name="chimp_5029",
|
| 19 |
+
data="chimp_5029.pkl",
|
| 20 |
+
geodists="geodists/geodists_chimp_5029.pkl",
|
| 21 |
+
symmetry="symmetry/symmetry_chimp_5029.pkl",
|
| 22 |
+
texcoords="texcoords/texcoords_chimp_5029.pkl",
|
| 23 |
+
),
|
| 24 |
+
MeshInfo(
|
| 25 |
+
name="cat_5001",
|
| 26 |
+
data="cat_5001.pkl",
|
| 27 |
+
geodists="geodists/geodists_cat_5001.pkl",
|
| 28 |
+
symmetry="symmetry/symmetry_cat_5001.pkl",
|
| 29 |
+
texcoords="texcoords/texcoords_cat_5001.pkl",
|
| 30 |
+
),
|
| 31 |
+
MeshInfo(
|
| 32 |
+
name="cat_7466",
|
| 33 |
+
data="cat_7466.pkl",
|
| 34 |
+
geodists="geodists/geodists_cat_7466.pkl",
|
| 35 |
+
symmetry="symmetry/symmetry_cat_7466.pkl",
|
| 36 |
+
texcoords="texcoords/texcoords_cat_7466.pkl",
|
| 37 |
+
),
|
| 38 |
+
MeshInfo(
|
| 39 |
+
name="sheep_5004",
|
| 40 |
+
data="sheep_5004.pkl",
|
| 41 |
+
geodists="geodists/geodists_sheep_5004.pkl",
|
| 42 |
+
symmetry="symmetry/symmetry_sheep_5004.pkl",
|
| 43 |
+
texcoords="texcoords/texcoords_sheep_5004.pkl",
|
| 44 |
+
),
|
| 45 |
+
MeshInfo(
|
| 46 |
+
name="zebra_5002",
|
| 47 |
+
data="zebra_5002.pkl",
|
| 48 |
+
geodists="geodists/geodists_zebra_5002.pkl",
|
| 49 |
+
symmetry="symmetry/symmetry_zebra_5002.pkl",
|
| 50 |
+
texcoords="texcoords/texcoords_zebra_5002.pkl",
|
| 51 |
+
),
|
| 52 |
+
MeshInfo(
|
| 53 |
+
name="horse_5004",
|
| 54 |
+
data="horse_5004.pkl",
|
| 55 |
+
geodists="geodists/geodists_horse_5004.pkl",
|
| 56 |
+
symmetry="symmetry/symmetry_horse_5004.pkl",
|
| 57 |
+
texcoords="texcoords/texcoords_zebra_5002.pkl",
|
| 58 |
+
),
|
| 59 |
+
MeshInfo(
|
| 60 |
+
name="giraffe_5002",
|
| 61 |
+
data="giraffe_5002.pkl",
|
| 62 |
+
geodists="geodists/geodists_giraffe_5002.pkl",
|
| 63 |
+
symmetry="symmetry/symmetry_giraffe_5002.pkl",
|
| 64 |
+
texcoords="texcoords/texcoords_giraffe_5002.pkl",
|
| 65 |
+
),
|
| 66 |
+
MeshInfo(
|
| 67 |
+
name="elephant_5002",
|
| 68 |
+
data="elephant_5002.pkl",
|
| 69 |
+
geodists="geodists/geodists_elephant_5002.pkl",
|
| 70 |
+
symmetry="symmetry/symmetry_elephant_5002.pkl",
|
| 71 |
+
texcoords="texcoords/texcoords_elephant_5002.pkl",
|
| 72 |
+
),
|
| 73 |
+
MeshInfo(
|
| 74 |
+
name="dog_5002",
|
| 75 |
+
data="dog_5002.pkl",
|
| 76 |
+
geodists="geodists/geodists_dog_5002.pkl",
|
| 77 |
+
symmetry="symmetry/symmetry_dog_5002.pkl",
|
| 78 |
+
texcoords="texcoords/texcoords_dog_5002.pkl",
|
| 79 |
+
),
|
| 80 |
+
MeshInfo(
|
| 81 |
+
name="dog_7466",
|
| 82 |
+
data="dog_7466.pkl",
|
| 83 |
+
geodists="geodists/geodists_dog_7466.pkl",
|
| 84 |
+
symmetry="symmetry/symmetry_dog_7466.pkl",
|
| 85 |
+
texcoords="texcoords/texcoords_dog_7466.pkl",
|
| 86 |
+
),
|
| 87 |
+
MeshInfo(
|
| 88 |
+
name="cow_5002",
|
| 89 |
+
data="cow_5002.pkl",
|
| 90 |
+
geodists="geodists/geodists_cow_5002.pkl",
|
| 91 |
+
symmetry="symmetry/symmetry_cow_5002.pkl",
|
| 92 |
+
texcoords="texcoords/texcoords_cow_5002.pkl",
|
| 93 |
+
),
|
| 94 |
+
MeshInfo(
|
| 95 |
+
name="bear_4936",
|
| 96 |
+
data="bear_4936.pkl",
|
| 97 |
+
geodists="geodists/geodists_bear_4936.pkl",
|
| 98 |
+
symmetry="symmetry/symmetry_bear_4936.pkl",
|
| 99 |
+
texcoords="texcoords/texcoords_bear_4936.pkl",
|
| 100 |
+
),
|
| 101 |
+
]
|
| 102 |
+
|
| 103 |
+
register_meshes(MESHES, DENSEPOSE_MESHES_DIR)
|
densepose/data/meshes/catalog.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
from collections import UserDict
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import Iterable, Optional
|
| 9 |
+
|
| 10 |
+
from ..utils import maybe_prepend_base_path
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
|
| 14 |
+
class MeshInfo:
|
| 15 |
+
name: str
|
| 16 |
+
data: str
|
| 17 |
+
geodists: Optional[str] = None
|
| 18 |
+
symmetry: Optional[str] = None
|
| 19 |
+
texcoords: Optional[str] = None
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class _MeshCatalog(UserDict):
|
| 23 |
+
def __init__(self, *args, **kwargs):
|
| 24 |
+
super().__init__(*args, **kwargs)
|
| 25 |
+
self.mesh_ids = {}
|
| 26 |
+
self.mesh_names = {}
|
| 27 |
+
self.max_mesh_id = -1
|
| 28 |
+
|
| 29 |
+
def __setitem__(self, key, value):
|
| 30 |
+
if key in self:
|
| 31 |
+
logger = logging.getLogger(__name__)
|
| 32 |
+
logger.warning(
|
| 33 |
+
f"Overwriting mesh catalog entry '{key}': old value {self[key]}"
|
| 34 |
+
f", new value {value}"
|
| 35 |
+
)
|
| 36 |
+
mesh_id = self.mesh_ids[key]
|
| 37 |
+
else:
|
| 38 |
+
self.max_mesh_id += 1
|
| 39 |
+
mesh_id = self.max_mesh_id
|
| 40 |
+
super().__setitem__(key, value)
|
| 41 |
+
self.mesh_ids[key] = mesh_id
|
| 42 |
+
self.mesh_names[mesh_id] = key
|
| 43 |
+
|
| 44 |
+
def get_mesh_id(self, shape_name: str) -> int:
|
| 45 |
+
return self.mesh_ids[shape_name]
|
| 46 |
+
|
| 47 |
+
def get_mesh_name(self, mesh_id: int) -> str:
|
| 48 |
+
return self.mesh_names[mesh_id]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
MeshCatalog = _MeshCatalog()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def register_mesh(mesh_info: MeshInfo, base_path: Optional[str]) -> None:
|
| 55 |
+
geodists, symmetry, texcoords = mesh_info.geodists, mesh_info.symmetry, mesh_info.texcoords
|
| 56 |
+
if geodists:
|
| 57 |
+
geodists = maybe_prepend_base_path(base_path, geodists)
|
| 58 |
+
if symmetry:
|
| 59 |
+
symmetry = maybe_prepend_base_path(base_path, symmetry)
|
| 60 |
+
if texcoords:
|
| 61 |
+
texcoords = maybe_prepend_base_path(base_path, texcoords)
|
| 62 |
+
MeshCatalog[mesh_info.name] = MeshInfo(
|
| 63 |
+
name=mesh_info.name,
|
| 64 |
+
data=maybe_prepend_base_path(base_path, mesh_info.data),
|
| 65 |
+
geodists=geodists,
|
| 66 |
+
symmetry=symmetry,
|
| 67 |
+
texcoords=texcoords,
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def register_meshes(mesh_infos: Iterable[MeshInfo], base_path: Optional[str]) -> None:
|
| 72 |
+
for mesh_info in mesh_infos:
|
| 73 |
+
register_mesh(mesh_info, base_path)
|
densepose/data/samplers/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from .densepose_uniform import DensePoseUniformSampler
|
| 6 |
+
from .densepose_confidence_based import DensePoseConfidenceBasedSampler
|
| 7 |
+
from .densepose_cse_uniform import DensePoseCSEUniformSampler
|
| 8 |
+
from .densepose_cse_confidence_based import DensePoseCSEConfidenceBasedSampler
|
| 9 |
+
from .mask_from_densepose import MaskFromDensePoseSampler
|
| 10 |
+
from .prediction_to_gt import PredictionToGroundTruthSampler
|
densepose/data/samplers/densepose_base.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict, List, Tuple
|
| 6 |
+
import torch
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
|
| 9 |
+
from detectron2.structures import BoxMode, Instances
|
| 10 |
+
|
| 11 |
+
from densepose.converters import ToChartResultConverter
|
| 12 |
+
from densepose.converters.base import IntTupleBox, make_int_box
|
| 13 |
+
from densepose.structures import DensePoseDataRelative, DensePoseList
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class DensePoseBaseSampler:
|
| 17 |
+
"""
|
| 18 |
+
Base DensePose sampler to produce DensePose data from DensePose predictions.
|
| 19 |
+
Samples for each class are drawn according to some distribution over all pixels estimated
|
| 20 |
+
to belong to that class.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(self, count_per_class: int = 8):
|
| 24 |
+
"""
|
| 25 |
+
Constructor
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
count_per_class (int): the sampler produces at most `count_per_class`
|
| 29 |
+
samples for each category
|
| 30 |
+
"""
|
| 31 |
+
self.count_per_class = count_per_class
|
| 32 |
+
|
| 33 |
+
def __call__(self, instances: Instances) -> DensePoseList:
|
| 34 |
+
"""
|
| 35 |
+
Convert DensePose predictions (an instance of `DensePoseChartPredictorOutput`)
|
| 36 |
+
into DensePose annotations data (an instance of `DensePoseList`)
|
| 37 |
+
"""
|
| 38 |
+
boxes_xyxy_abs = instances.pred_boxes.tensor.clone().cpu()
|
| 39 |
+
boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
|
| 40 |
+
dp_datas = []
|
| 41 |
+
for i in range(len(boxes_xywh_abs)):
|
| 42 |
+
annotation_i = self._sample(instances[i], make_int_box(boxes_xywh_abs[i]))
|
| 43 |
+
annotation_i[DensePoseDataRelative.S_KEY] = self._resample_mask( # pyre-ignore[6]
|
| 44 |
+
instances[i].pred_densepose
|
| 45 |
+
)
|
| 46 |
+
dp_datas.append(DensePoseDataRelative(annotation_i))
|
| 47 |
+
# create densepose annotations on CPU
|
| 48 |
+
dp_list = DensePoseList(dp_datas, boxes_xyxy_abs, instances.image_size)
|
| 49 |
+
return dp_list
|
| 50 |
+
|
| 51 |
+
def _sample(self, instance: Instances, bbox_xywh: IntTupleBox) -> Dict[str, List[Any]]:
|
| 52 |
+
"""
|
| 53 |
+
Sample DensPoseDataRelative from estimation results
|
| 54 |
+
"""
|
| 55 |
+
labels, dp_result = self._produce_labels_and_results(instance)
|
| 56 |
+
annotation = {
|
| 57 |
+
DensePoseDataRelative.X_KEY: [],
|
| 58 |
+
DensePoseDataRelative.Y_KEY: [],
|
| 59 |
+
DensePoseDataRelative.U_KEY: [],
|
| 60 |
+
DensePoseDataRelative.V_KEY: [],
|
| 61 |
+
DensePoseDataRelative.I_KEY: [],
|
| 62 |
+
}
|
| 63 |
+
n, h, w = dp_result.shape
|
| 64 |
+
for part_id in range(1, DensePoseDataRelative.N_PART_LABELS + 1):
|
| 65 |
+
# indices - tuple of 3 1D tensors of size k
|
| 66 |
+
# 0: index along the first dimension N
|
| 67 |
+
# 1: index along H dimension
|
| 68 |
+
# 2: index along W dimension
|
| 69 |
+
indices = torch.nonzero(labels.expand(n, h, w) == part_id, as_tuple=True)
|
| 70 |
+
# values - an array of size [n, k]
|
| 71 |
+
# n: number of channels (U, V, confidences)
|
| 72 |
+
# k: number of points labeled with part_id
|
| 73 |
+
values = dp_result[indices].view(n, -1)
|
| 74 |
+
k = values.shape[1]
|
| 75 |
+
count = min(self.count_per_class, k)
|
| 76 |
+
if count <= 0:
|
| 77 |
+
continue
|
| 78 |
+
index_sample = self._produce_index_sample(values, count)
|
| 79 |
+
sampled_values = values[:, index_sample]
|
| 80 |
+
sampled_y = indices[1][index_sample] + 0.5
|
| 81 |
+
sampled_x = indices[2][index_sample] + 0.5
|
| 82 |
+
# prepare / normalize data
|
| 83 |
+
x = (sampled_x / w * 256.0).cpu().tolist()
|
| 84 |
+
y = (sampled_y / h * 256.0).cpu().tolist()
|
| 85 |
+
u = sampled_values[0].clamp(0, 1).cpu().tolist()
|
| 86 |
+
v = sampled_values[1].clamp(0, 1).cpu().tolist()
|
| 87 |
+
fine_segm_labels = [part_id] * count
|
| 88 |
+
# extend annotations
|
| 89 |
+
annotation[DensePoseDataRelative.X_KEY].extend(x)
|
| 90 |
+
annotation[DensePoseDataRelative.Y_KEY].extend(y)
|
| 91 |
+
annotation[DensePoseDataRelative.U_KEY].extend(u)
|
| 92 |
+
annotation[DensePoseDataRelative.V_KEY].extend(v)
|
| 93 |
+
annotation[DensePoseDataRelative.I_KEY].extend(fine_segm_labels)
|
| 94 |
+
return annotation
|
| 95 |
+
|
| 96 |
+
def _produce_index_sample(self, values: torch.Tensor, count: int):
|
| 97 |
+
"""
|
| 98 |
+
Abstract method to produce a sample of indices to select data
|
| 99 |
+
To be implemented in descendants
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
values (torch.Tensor): an array of size [n, k] that contains
|
| 103 |
+
estimated values (U, V, confidences);
|
| 104 |
+
n: number of channels (U, V, confidences)
|
| 105 |
+
k: number of points labeled with part_id
|
| 106 |
+
count (int): number of samples to produce, should be positive and <= k
|
| 107 |
+
|
| 108 |
+
Return:
|
| 109 |
+
list(int): indices of values (along axis 1) selected as a sample
|
| 110 |
+
"""
|
| 111 |
+
raise NotImplementedError
|
| 112 |
+
|
| 113 |
+
def _produce_labels_and_results(self, instance: Instances) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 114 |
+
"""
|
| 115 |
+
Method to get labels and DensePose results from an instance
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
instance (Instances): an instance of `DensePoseChartPredictorOutput`
|
| 119 |
+
|
| 120 |
+
Return:
|
| 121 |
+
labels (torch.Tensor): shape [H, W], DensePose segmentation labels
|
| 122 |
+
dp_result (torch.Tensor): shape [2, H, W], stacked DensePose results u and v
|
| 123 |
+
"""
|
| 124 |
+
converter = ToChartResultConverter
|
| 125 |
+
chart_result = converter.convert(instance.pred_densepose, instance.pred_boxes)
|
| 126 |
+
labels, dp_result = chart_result.labels.cpu(), chart_result.uv.cpu()
|
| 127 |
+
return labels, dp_result
|
| 128 |
+
|
| 129 |
+
def _resample_mask(self, output: Any) -> torch.Tensor:
|
| 130 |
+
"""
|
| 131 |
+
Convert DensePose predictor output to segmentation annotation - tensors of size
|
| 132 |
+
(256, 256) and type `int64`.
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
output: DensePose predictor output with the following attributes:
|
| 136 |
+
- coarse_segm: tensor of size [N, D, H, W] with unnormalized coarse
|
| 137 |
+
segmentation scores
|
| 138 |
+
- fine_segm: tensor of size [N, C, H, W] with unnormalized fine
|
| 139 |
+
segmentation scores
|
| 140 |
+
Return:
|
| 141 |
+
Tensor of size (S, S) and type `int64` with coarse segmentation annotations,
|
| 142 |
+
where S = DensePoseDataRelative.MASK_SIZE
|
| 143 |
+
"""
|
| 144 |
+
sz = DensePoseDataRelative.MASK_SIZE
|
| 145 |
+
S = (
|
| 146 |
+
F.interpolate(output.coarse_segm, (sz, sz), mode="bilinear", align_corners=False)
|
| 147 |
+
.argmax(dim=1)
|
| 148 |
+
.long()
|
| 149 |
+
)
|
| 150 |
+
I = (
|
| 151 |
+
(
|
| 152 |
+
F.interpolate(
|
| 153 |
+
output.fine_segm,
|
| 154 |
+
(sz, sz),
|
| 155 |
+
mode="bilinear",
|
| 156 |
+
align_corners=False,
|
| 157 |
+
).argmax(dim=1)
|
| 158 |
+
* (S > 0).long()
|
| 159 |
+
)
|
| 160 |
+
.squeeze()
|
| 161 |
+
.cpu()
|
| 162 |
+
)
|
| 163 |
+
# Map fine segmentation results to coarse segmentation ground truth
|
| 164 |
+
# TODO: extract this into separate classes
|
| 165 |
+
# coarse segmentation: 1 = Torso, 2 = Right Hand, 3 = Left Hand,
|
| 166 |
+
# 4 = Left Foot, 5 = Right Foot, 6 = Upper Leg Right, 7 = Upper Leg Left,
|
| 167 |
+
# 8 = Lower Leg Right, 9 = Lower Leg Left, 10 = Upper Arm Left,
|
| 168 |
+
# 11 = Upper Arm Right, 12 = Lower Arm Left, 13 = Lower Arm Right,
|
| 169 |
+
# 14 = Head
|
| 170 |
+
# fine segmentation: 1, 2 = Torso, 3 = Right Hand, 4 = Left Hand,
|
| 171 |
+
# 5 = Left Foot, 6 = Right Foot, 7, 9 = Upper Leg Right,
|
| 172 |
+
# 8, 10 = Upper Leg Left, 11, 13 = Lower Leg Right,
|
| 173 |
+
# 12, 14 = Lower Leg Left, 15, 17 = Upper Arm Left,
|
| 174 |
+
# 16, 18 = Upper Arm Right, 19, 21 = Lower Arm Left,
|
| 175 |
+
# 20, 22 = Lower Arm Right, 23, 24 = Head
|
| 176 |
+
FINE_TO_COARSE_SEGMENTATION = {
|
| 177 |
+
1: 1,
|
| 178 |
+
2: 1,
|
| 179 |
+
3: 2,
|
| 180 |
+
4: 3,
|
| 181 |
+
5: 4,
|
| 182 |
+
6: 5,
|
| 183 |
+
7: 6,
|
| 184 |
+
8: 7,
|
| 185 |
+
9: 6,
|
| 186 |
+
10: 7,
|
| 187 |
+
11: 8,
|
| 188 |
+
12: 9,
|
| 189 |
+
13: 8,
|
| 190 |
+
14: 9,
|
| 191 |
+
15: 10,
|
| 192 |
+
16: 11,
|
| 193 |
+
17: 10,
|
| 194 |
+
18: 11,
|
| 195 |
+
19: 12,
|
| 196 |
+
20: 13,
|
| 197 |
+
21: 12,
|
| 198 |
+
22: 13,
|
| 199 |
+
23: 14,
|
| 200 |
+
24: 14,
|
| 201 |
+
}
|
| 202 |
+
mask = torch.zeros((sz, sz), dtype=torch.int64, device=torch.device("cpu"))
|
| 203 |
+
for i in range(DensePoseDataRelative.N_PART_LABELS):
|
| 204 |
+
mask[I == i + 1] = FINE_TO_COARSE_SEGMENTATION[i + 1]
|
| 205 |
+
return mask
|
densepose/data/samplers/densepose_confidence_based.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
import random
|
| 6 |
+
from typing import Optional, Tuple
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from densepose.converters import ToChartResultConverterWithConfidences
|
| 10 |
+
|
| 11 |
+
from .densepose_base import DensePoseBaseSampler
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class DensePoseConfidenceBasedSampler(DensePoseBaseSampler):
|
| 15 |
+
"""
|
| 16 |
+
Samples DensePose data from DensePose predictions.
|
| 17 |
+
Samples for each class are drawn using confidence value estimates.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(
|
| 21 |
+
self,
|
| 22 |
+
confidence_channel: str,
|
| 23 |
+
count_per_class: int = 8,
|
| 24 |
+
search_count_multiplier: Optional[float] = None,
|
| 25 |
+
search_proportion: Optional[float] = None,
|
| 26 |
+
):
|
| 27 |
+
"""
|
| 28 |
+
Constructor
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
confidence_channel (str): confidence channel to use for sampling;
|
| 32 |
+
possible values:
|
| 33 |
+
"sigma_2": confidences for UV values
|
| 34 |
+
"fine_segm_confidence": confidences for fine segmentation
|
| 35 |
+
"coarse_segm_confidence": confidences for coarse segmentation
|
| 36 |
+
(default: "sigma_2")
|
| 37 |
+
count_per_class (int): the sampler produces at most `count_per_class`
|
| 38 |
+
samples for each category (default: 8)
|
| 39 |
+
search_count_multiplier (float or None): if not None, the total number
|
| 40 |
+
of the most confident estimates of a given class to consider is
|
| 41 |
+
defined as `min(search_count_multiplier * count_per_class, N)`,
|
| 42 |
+
where `N` is the total number of estimates of the class; cannot be
|
| 43 |
+
specified together with `search_proportion` (default: None)
|
| 44 |
+
search_proportion (float or None): if not None, the total number of the
|
| 45 |
+
of the most confident estimates of a given class to consider is
|
| 46 |
+
defined as `min(max(search_proportion * N, count_per_class), N)`,
|
| 47 |
+
where `N` is the total number of estimates of the class; cannot be
|
| 48 |
+
specified together with `search_count_multiplier` (default: None)
|
| 49 |
+
"""
|
| 50 |
+
super().__init__(count_per_class)
|
| 51 |
+
self.confidence_channel = confidence_channel
|
| 52 |
+
self.search_count_multiplier = search_count_multiplier
|
| 53 |
+
self.search_proportion = search_proportion
|
| 54 |
+
assert (search_count_multiplier is None) or (search_proportion is None), (
|
| 55 |
+
f"Cannot specify both search_count_multiplier (={search_count_multiplier})"
|
| 56 |
+
f"and search_proportion (={search_proportion})"
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
def _produce_index_sample(self, values: torch.Tensor, count: int):
|
| 60 |
+
"""
|
| 61 |
+
Produce a sample of indices to select data based on confidences
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
values (torch.Tensor): an array of size [n, k] that contains
|
| 65 |
+
estimated values (U, V, confidences);
|
| 66 |
+
n: number of channels (U, V, confidences)
|
| 67 |
+
k: number of points labeled with part_id
|
| 68 |
+
count (int): number of samples to produce, should be positive and <= k
|
| 69 |
+
|
| 70 |
+
Return:
|
| 71 |
+
list(int): indices of values (along axis 1) selected as a sample
|
| 72 |
+
"""
|
| 73 |
+
k = values.shape[1]
|
| 74 |
+
if k == count:
|
| 75 |
+
index_sample = list(range(k))
|
| 76 |
+
else:
|
| 77 |
+
# take the best count * search_count_multiplier pixels,
|
| 78 |
+
# sample from them uniformly
|
| 79 |
+
# (here best = smallest variance)
|
| 80 |
+
_, sorted_confidence_indices = torch.sort(values[2])
|
| 81 |
+
if self.search_count_multiplier is not None:
|
| 82 |
+
search_count = min(int(count * self.search_count_multiplier), k)
|
| 83 |
+
elif self.search_proportion is not None:
|
| 84 |
+
search_count = min(max(int(k * self.search_proportion), count), k)
|
| 85 |
+
else:
|
| 86 |
+
search_count = min(count, k)
|
| 87 |
+
sample_from_top = random.sample(range(search_count), count)
|
| 88 |
+
index_sample = sorted_confidence_indices[:search_count][sample_from_top]
|
| 89 |
+
return index_sample
|
| 90 |
+
|
| 91 |
+
def _produce_labels_and_results(self, instance) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 92 |
+
"""
|
| 93 |
+
Method to get labels and DensePose results from an instance, with confidences
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
instance (Instances): an instance of `DensePoseChartPredictorOutputWithConfidences`
|
| 97 |
+
|
| 98 |
+
Return:
|
| 99 |
+
labels (torch.Tensor): shape [H, W], DensePose segmentation labels
|
| 100 |
+
dp_result (torch.Tensor): shape [3, H, W], DensePose results u and v
|
| 101 |
+
stacked with the confidence channel
|
| 102 |
+
"""
|
| 103 |
+
converter = ToChartResultConverterWithConfidences
|
| 104 |
+
chart_result = converter.convert(instance.pred_densepose, instance.pred_boxes)
|
| 105 |
+
labels, dp_result = chart_result.labels.cpu(), chart_result.uv.cpu()
|
| 106 |
+
dp_result = torch.cat(
|
| 107 |
+
(dp_result, getattr(chart_result, self.confidence_channel)[None].cpu())
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
return labels, dp_result
|
densepose/data/samplers/densepose_cse_base.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict, List, Tuple
|
| 6 |
+
import torch
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
|
| 9 |
+
from detectron2.config import CfgNode
|
| 10 |
+
from detectron2.structures import Instances
|
| 11 |
+
|
| 12 |
+
from densepose.converters.base import IntTupleBox
|
| 13 |
+
from densepose.data.utils import get_class_to_mesh_name_mapping
|
| 14 |
+
from densepose.modeling.cse.utils import squared_euclidean_distance_matrix
|
| 15 |
+
from densepose.structures import DensePoseDataRelative
|
| 16 |
+
|
| 17 |
+
from .densepose_base import DensePoseBaseSampler
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class DensePoseCSEBaseSampler(DensePoseBaseSampler):
|
| 21 |
+
"""
|
| 22 |
+
Base DensePose sampler to produce DensePose data from DensePose predictions.
|
| 23 |
+
Samples for each class are drawn according to some distribution over all pixels estimated
|
| 24 |
+
to belong to that class.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
cfg: CfgNode,
|
| 30 |
+
use_gt_categories: bool,
|
| 31 |
+
embedder: torch.nn.Module,
|
| 32 |
+
count_per_class: int = 8,
|
| 33 |
+
):
|
| 34 |
+
"""
|
| 35 |
+
Constructor
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
cfg (CfgNode): the config of the model
|
| 39 |
+
embedder (torch.nn.Module): necessary to compute mesh vertex embeddings
|
| 40 |
+
count_per_class (int): the sampler produces at most `count_per_class`
|
| 41 |
+
samples for each category
|
| 42 |
+
"""
|
| 43 |
+
super().__init__(count_per_class)
|
| 44 |
+
self.embedder = embedder
|
| 45 |
+
self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
|
| 46 |
+
self.use_gt_categories = use_gt_categories
|
| 47 |
+
|
| 48 |
+
def _sample(self, instance: Instances, bbox_xywh: IntTupleBox) -> Dict[str, List[Any]]:
|
| 49 |
+
"""
|
| 50 |
+
Sample DensPoseDataRelative from estimation results
|
| 51 |
+
"""
|
| 52 |
+
if self.use_gt_categories:
|
| 53 |
+
instance_class = instance.dataset_classes.tolist()[0]
|
| 54 |
+
else:
|
| 55 |
+
instance_class = instance.pred_classes.tolist()[0]
|
| 56 |
+
mesh_name = self.class_to_mesh_name[instance_class]
|
| 57 |
+
|
| 58 |
+
annotation = {
|
| 59 |
+
DensePoseDataRelative.X_KEY: [],
|
| 60 |
+
DensePoseDataRelative.Y_KEY: [],
|
| 61 |
+
DensePoseDataRelative.VERTEX_IDS_KEY: [],
|
| 62 |
+
DensePoseDataRelative.MESH_NAME_KEY: mesh_name,
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
mask, embeddings, other_values = self._produce_mask_and_results(instance, bbox_xywh)
|
| 66 |
+
indices = torch.nonzero(mask, as_tuple=True)
|
| 67 |
+
selected_embeddings = embeddings.permute(1, 2, 0)[indices].cpu()
|
| 68 |
+
values = other_values[:, indices[0], indices[1]]
|
| 69 |
+
k = values.shape[1]
|
| 70 |
+
|
| 71 |
+
count = min(self.count_per_class, k)
|
| 72 |
+
if count <= 0:
|
| 73 |
+
return annotation
|
| 74 |
+
|
| 75 |
+
index_sample = self._produce_index_sample(values, count)
|
| 76 |
+
closest_vertices = squared_euclidean_distance_matrix(
|
| 77 |
+
selected_embeddings[index_sample], self.embedder(mesh_name)
|
| 78 |
+
)
|
| 79 |
+
closest_vertices = torch.argmin(closest_vertices, dim=1)
|
| 80 |
+
|
| 81 |
+
sampled_y = indices[0][index_sample] + 0.5
|
| 82 |
+
sampled_x = indices[1][index_sample] + 0.5
|
| 83 |
+
# prepare / normalize data
|
| 84 |
+
_, _, w, h = bbox_xywh
|
| 85 |
+
x = (sampled_x / w * 256.0).cpu().tolist()
|
| 86 |
+
y = (sampled_y / h * 256.0).cpu().tolist()
|
| 87 |
+
# extend annotations
|
| 88 |
+
annotation[DensePoseDataRelative.X_KEY].extend(x)
|
| 89 |
+
annotation[DensePoseDataRelative.Y_KEY].extend(y)
|
| 90 |
+
annotation[DensePoseDataRelative.VERTEX_IDS_KEY].extend(closest_vertices.cpu().tolist())
|
| 91 |
+
return annotation
|
| 92 |
+
|
| 93 |
+
def _produce_mask_and_results(
|
| 94 |
+
self, instance: Instances, bbox_xywh: IntTupleBox
|
| 95 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 96 |
+
"""
|
| 97 |
+
Method to get labels and DensePose results from an instance
|
| 98 |
+
|
| 99 |
+
Args:
|
| 100 |
+
instance (Instances): an instance of `DensePoseEmbeddingPredictorOutput`
|
| 101 |
+
bbox_xywh (IntTupleBox): the corresponding bounding box
|
| 102 |
+
|
| 103 |
+
Return:
|
| 104 |
+
mask (torch.Tensor): shape [H, W], DensePose segmentation mask
|
| 105 |
+
embeddings (Tuple[torch.Tensor]): a tensor of shape [D, H, W],
|
| 106 |
+
DensePose CSE Embeddings
|
| 107 |
+
other_values (Tuple[torch.Tensor]): a tensor of shape [0, H, W],
|
| 108 |
+
for potential other values
|
| 109 |
+
"""
|
| 110 |
+
densepose_output = instance.pred_densepose
|
| 111 |
+
S = densepose_output.coarse_segm
|
| 112 |
+
E = densepose_output.embedding
|
| 113 |
+
_, _, w, h = bbox_xywh
|
| 114 |
+
embeddings = F.interpolate(E, size=(h, w), mode="bilinear")[0]
|
| 115 |
+
coarse_segm_resized = F.interpolate(S, size=(h, w), mode="bilinear")[0]
|
| 116 |
+
mask = coarse_segm_resized.argmax(0) > 0
|
| 117 |
+
other_values = torch.empty((0, h, w), device=E.device)
|
| 118 |
+
return mask, embeddings, other_values
|
| 119 |
+
|
| 120 |
+
def _resample_mask(self, output: Any) -> torch.Tensor:
|
| 121 |
+
"""
|
| 122 |
+
Convert DensePose predictor output to segmentation annotation - tensors of size
|
| 123 |
+
(256, 256) and type `int64`.
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
output: DensePose predictor output with the following attributes:
|
| 127 |
+
- coarse_segm: tensor of size [N, D, H, W] with unnormalized coarse
|
| 128 |
+
segmentation scores
|
| 129 |
+
Return:
|
| 130 |
+
Tensor of size (S, S) and type `int64` with coarse segmentation annotations,
|
| 131 |
+
where S = DensePoseDataRelative.MASK_SIZE
|
| 132 |
+
"""
|
| 133 |
+
sz = DensePoseDataRelative.MASK_SIZE
|
| 134 |
+
mask = (
|
| 135 |
+
F.interpolate(output.coarse_segm, (sz, sz), mode="bilinear", align_corners=False)
|
| 136 |
+
.argmax(dim=1)
|
| 137 |
+
.long()
|
| 138 |
+
.squeeze()
|
| 139 |
+
.cpu()
|
| 140 |
+
)
|
| 141 |
+
return mask
|
densepose/data/samplers/densepose_cse_confidence_based.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
import random
|
| 6 |
+
from typing import Optional, Tuple
|
| 7 |
+
import torch
|
| 8 |
+
from torch.nn import functional as F
|
| 9 |
+
|
| 10 |
+
from detectron2.config import CfgNode
|
| 11 |
+
from detectron2.structures import Instances
|
| 12 |
+
|
| 13 |
+
from densepose.converters.base import IntTupleBox
|
| 14 |
+
|
| 15 |
+
from .densepose_cse_base import DensePoseCSEBaseSampler
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class DensePoseCSEConfidenceBasedSampler(DensePoseCSEBaseSampler):
|
| 19 |
+
"""
|
| 20 |
+
Samples DensePose data from DensePose predictions.
|
| 21 |
+
Samples for each class are drawn using confidence value estimates.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
cfg: CfgNode,
|
| 27 |
+
use_gt_categories: bool,
|
| 28 |
+
embedder: torch.nn.Module,
|
| 29 |
+
confidence_channel: str,
|
| 30 |
+
count_per_class: int = 8,
|
| 31 |
+
search_count_multiplier: Optional[float] = None,
|
| 32 |
+
search_proportion: Optional[float] = None,
|
| 33 |
+
):
|
| 34 |
+
"""
|
| 35 |
+
Constructor
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
cfg (CfgNode): the config of the model
|
| 39 |
+
embedder (torch.nn.Module): necessary to compute mesh vertex embeddings
|
| 40 |
+
confidence_channel (str): confidence channel to use for sampling;
|
| 41 |
+
possible values:
|
| 42 |
+
"coarse_segm_confidence": confidences for coarse segmentation
|
| 43 |
+
(default: "coarse_segm_confidence")
|
| 44 |
+
count_per_class (int): the sampler produces at most `count_per_class`
|
| 45 |
+
samples for each category (default: 8)
|
| 46 |
+
search_count_multiplier (float or None): if not None, the total number
|
| 47 |
+
of the most confident estimates of a given class to consider is
|
| 48 |
+
defined as `min(search_count_multiplier * count_per_class, N)`,
|
| 49 |
+
where `N` is the total number of estimates of the class; cannot be
|
| 50 |
+
specified together with `search_proportion` (default: None)
|
| 51 |
+
search_proportion (float or None): if not None, the total number of the
|
| 52 |
+
of the most confident estimates of a given class to consider is
|
| 53 |
+
defined as `min(max(search_proportion * N, count_per_class), N)`,
|
| 54 |
+
where `N` is the total number of estimates of the class; cannot be
|
| 55 |
+
specified together with `search_count_multiplier` (default: None)
|
| 56 |
+
"""
|
| 57 |
+
super().__init__(cfg, use_gt_categories, embedder, count_per_class)
|
| 58 |
+
self.confidence_channel = confidence_channel
|
| 59 |
+
self.search_count_multiplier = search_count_multiplier
|
| 60 |
+
self.search_proportion = search_proportion
|
| 61 |
+
assert (search_count_multiplier is None) or (search_proportion is None), (
|
| 62 |
+
f"Cannot specify both search_count_multiplier (={search_count_multiplier})"
|
| 63 |
+
f"and search_proportion (={search_proportion})"
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
def _produce_index_sample(self, values: torch.Tensor, count: int):
|
| 67 |
+
"""
|
| 68 |
+
Produce a sample of indices to select data based on confidences
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
values (torch.Tensor): a tensor of length k that contains confidences
|
| 72 |
+
k: number of points labeled with part_id
|
| 73 |
+
count (int): number of samples to produce, should be positive and <= k
|
| 74 |
+
|
| 75 |
+
Return:
|
| 76 |
+
list(int): indices of values (along axis 1) selected as a sample
|
| 77 |
+
"""
|
| 78 |
+
k = values.shape[1]
|
| 79 |
+
if k == count:
|
| 80 |
+
index_sample = list(range(k))
|
| 81 |
+
else:
|
| 82 |
+
# take the best count * search_count_multiplier pixels,
|
| 83 |
+
# sample from them uniformly
|
| 84 |
+
# (here best = smallest variance)
|
| 85 |
+
_, sorted_confidence_indices = torch.sort(values[0])
|
| 86 |
+
if self.search_count_multiplier is not None:
|
| 87 |
+
search_count = min(int(count * self.search_count_multiplier), k)
|
| 88 |
+
elif self.search_proportion is not None:
|
| 89 |
+
search_count = min(max(int(k * self.search_proportion), count), k)
|
| 90 |
+
else:
|
| 91 |
+
search_count = min(count, k)
|
| 92 |
+
sample_from_top = random.sample(range(search_count), count)
|
| 93 |
+
index_sample = sorted_confidence_indices[-search_count:][sample_from_top]
|
| 94 |
+
return index_sample
|
| 95 |
+
|
| 96 |
+
def _produce_mask_and_results(
|
| 97 |
+
self, instance: Instances, bbox_xywh: IntTupleBox
|
| 98 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 99 |
+
"""
|
| 100 |
+
Method to get labels and DensePose results from an instance
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
instance (Instances): an instance of
|
| 104 |
+
`DensePoseEmbeddingPredictorOutputWithConfidences`
|
| 105 |
+
bbox_xywh (IntTupleBox): the corresponding bounding box
|
| 106 |
+
|
| 107 |
+
Return:
|
| 108 |
+
mask (torch.Tensor): shape [H, W], DensePose segmentation mask
|
| 109 |
+
embeddings (Tuple[torch.Tensor]): a tensor of shape [D, H, W]
|
| 110 |
+
DensePose CSE Embeddings
|
| 111 |
+
other_values: a tensor of shape [1, H, W], DensePose CSE confidence
|
| 112 |
+
"""
|
| 113 |
+
_, _, w, h = bbox_xywh
|
| 114 |
+
densepose_output = instance.pred_densepose
|
| 115 |
+
mask, embeddings, _ = super()._produce_mask_and_results(instance, bbox_xywh)
|
| 116 |
+
other_values = F.interpolate(
|
| 117 |
+
getattr(densepose_output, self.confidence_channel),
|
| 118 |
+
size=(h, w),
|
| 119 |
+
mode="bilinear",
|
| 120 |
+
)[0].cpu()
|
| 121 |
+
return mask, embeddings, other_values
|
densepose/data/samplers/densepose_cse_uniform.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from .densepose_cse_base import DensePoseCSEBaseSampler
|
| 6 |
+
from .densepose_uniform import DensePoseUniformSampler
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class DensePoseCSEUniformSampler(DensePoseCSEBaseSampler, DensePoseUniformSampler):
|
| 10 |
+
"""
|
| 11 |
+
Uniform Sampler for CSE
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
pass
|
densepose/data/samplers/densepose_uniform.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
import random
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from .densepose_base import DensePoseBaseSampler
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class DensePoseUniformSampler(DensePoseBaseSampler):
|
| 12 |
+
"""
|
| 13 |
+
Samples DensePose data from DensePose predictions.
|
| 14 |
+
Samples for each class are drawn uniformly over all pixels estimated
|
| 15 |
+
to belong to that class.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(self, count_per_class: int = 8):
|
| 19 |
+
"""
|
| 20 |
+
Constructor
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
count_per_class (int): the sampler produces at most `count_per_class`
|
| 24 |
+
samples for each category
|
| 25 |
+
"""
|
| 26 |
+
super().__init__(count_per_class)
|
| 27 |
+
|
| 28 |
+
def _produce_index_sample(self, values: torch.Tensor, count: int):
|
| 29 |
+
"""
|
| 30 |
+
Produce a uniform sample of indices to select data
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
values (torch.Tensor): an array of size [n, k] that contains
|
| 34 |
+
estimated values (U, V, confidences);
|
| 35 |
+
n: number of channels (U, V, confidences)
|
| 36 |
+
k: number of points labeled with part_id
|
| 37 |
+
count (int): number of samples to produce, should be positive and <= k
|
| 38 |
+
|
| 39 |
+
Return:
|
| 40 |
+
list(int): indices of values (along axis 1) selected as a sample
|
| 41 |
+
"""
|
| 42 |
+
k = values.shape[1]
|
| 43 |
+
return random.sample(range(k), count)
|
densepose/data/samplers/mask_from_densepose.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from detectron2.structures import BitMasks, Instances
|
| 6 |
+
|
| 7 |
+
from densepose.converters import ToMaskConverter
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class MaskFromDensePoseSampler:
|
| 11 |
+
"""
|
| 12 |
+
Produce mask GT from DensePose predictions
|
| 13 |
+
This sampler simply converts DensePose predictions to BitMasks
|
| 14 |
+
that a contain a bool tensor of the size of the input image
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
def __call__(self, instances: Instances) -> BitMasks:
|
| 18 |
+
"""
|
| 19 |
+
Converts predicted data from `instances` into the GT mask data
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
instances (Instances): predicted results, expected to have `pred_densepose` field
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
Boolean Tensor of the size of the input image that has non-zero
|
| 26 |
+
values at pixels that are estimated to belong to the detected object
|
| 27 |
+
"""
|
| 28 |
+
return ToMaskConverter.convert(
|
| 29 |
+
instances.pred_densepose, instances.pred_boxes, instances.image_size
|
| 30 |
+
)
|
densepose/data/samplers/prediction_to_gt.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Any, Callable, Dict, List, Optional
|
| 7 |
+
|
| 8 |
+
from detectron2.structures import Instances
|
| 9 |
+
|
| 10 |
+
ModelOutput = Dict[str, Any]
|
| 11 |
+
SampledData = Dict[str, Any]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
|
| 15 |
+
class _Sampler:
|
| 16 |
+
"""
|
| 17 |
+
Sampler registry entry that contains:
|
| 18 |
+
- src (str): source field to sample from (deleted after sampling)
|
| 19 |
+
- dst (Optional[str]): destination field to sample to, if not None
|
| 20 |
+
- func (Optional[Callable: Any -> Any]): function that performs sampling,
|
| 21 |
+
if None, reference copy is performed
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
src: str
|
| 25 |
+
dst: Optional[str]
|
| 26 |
+
func: Optional[Callable[[Any], Any]]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class PredictionToGroundTruthSampler:
|
| 30 |
+
"""
|
| 31 |
+
Sampler implementation that converts predictions to GT using registered
|
| 32 |
+
samplers for different fields of `Instances`.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(self, dataset_name: str = ""):
|
| 36 |
+
self.dataset_name = dataset_name
|
| 37 |
+
self._samplers = {}
|
| 38 |
+
self.register_sampler("pred_boxes", "gt_boxes", None)
|
| 39 |
+
self.register_sampler("pred_classes", "gt_classes", None)
|
| 40 |
+
# delete scores
|
| 41 |
+
self.register_sampler("scores")
|
| 42 |
+
|
| 43 |
+
def __call__(self, model_output: List[ModelOutput]) -> List[SampledData]:
|
| 44 |
+
"""
|
| 45 |
+
Transform model output into ground truth data through sampling
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
model_output (Dict[str, Any]): model output
|
| 49 |
+
Returns:
|
| 50 |
+
Dict[str, Any]: sampled data
|
| 51 |
+
"""
|
| 52 |
+
for model_output_i in model_output:
|
| 53 |
+
instances: Instances = model_output_i["instances"]
|
| 54 |
+
# transform data in each field
|
| 55 |
+
for _, sampler in self._samplers.items():
|
| 56 |
+
if not instances.has(sampler.src) or sampler.dst is None:
|
| 57 |
+
continue
|
| 58 |
+
if sampler.func is None:
|
| 59 |
+
instances.set(sampler.dst, instances.get(sampler.src))
|
| 60 |
+
else:
|
| 61 |
+
instances.set(sampler.dst, sampler.func(instances))
|
| 62 |
+
# delete model output data that was transformed
|
| 63 |
+
for _, sampler in self._samplers.items():
|
| 64 |
+
if sampler.src != sampler.dst and instances.has(sampler.src):
|
| 65 |
+
instances.remove(sampler.src)
|
| 66 |
+
model_output_i["dataset"] = self.dataset_name
|
| 67 |
+
return model_output
|
| 68 |
+
|
| 69 |
+
def register_sampler(
|
| 70 |
+
self,
|
| 71 |
+
prediction_attr: str,
|
| 72 |
+
gt_attr: Optional[str] = None,
|
| 73 |
+
func: Optional[Callable[[Any], Any]] = None,
|
| 74 |
+
):
|
| 75 |
+
"""
|
| 76 |
+
Register sampler for a field
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
prediction_attr (str): field to replace with a sampled value
|
| 80 |
+
gt_attr (Optional[str]): field to store the sampled value to, if not None
|
| 81 |
+
func (Optional[Callable: Any -> Any]): sampler function
|
| 82 |
+
"""
|
| 83 |
+
self._samplers[(prediction_attr, gt_attr)] = _Sampler(
|
| 84 |
+
src=prediction_attr, dst=gt_attr, func=func
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
def remove_sampler(
|
| 88 |
+
self,
|
| 89 |
+
prediction_attr: str,
|
| 90 |
+
gt_attr: Optional[str] = None,
|
| 91 |
+
):
|
| 92 |
+
"""
|
| 93 |
+
Remove sampler for a field
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
prediction_attr (str): field to replace with a sampled value
|
| 97 |
+
gt_attr (Optional[str]): field to store the sampled value to, if not None
|
| 98 |
+
"""
|
| 99 |
+
assert (prediction_attr, gt_attr) in self._samplers
|
| 100 |
+
del self._samplers[(prediction_attr, gt_attr)]
|
densepose/data/transform/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from .image import ImageResizeTransform
|
densepose/data/transform/image.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ImageResizeTransform:
|
| 9 |
+
"""
|
| 10 |
+
Transform that resizes images loaded from a dataset
|
| 11 |
+
(BGR data in NCHW channel order, typically uint8) to a format ready to be
|
| 12 |
+
consumed by DensePose training (BGR float32 data in NCHW channel order)
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
def __init__(self, min_size: int = 800, max_size: int = 1333):
|
| 16 |
+
self.min_size = min_size
|
| 17 |
+
self.max_size = max_size
|
| 18 |
+
|
| 19 |
+
def __call__(self, images: torch.Tensor) -> torch.Tensor:
|
| 20 |
+
"""
|
| 21 |
+
Args:
|
| 22 |
+
images (torch.Tensor): tensor of size [N, 3, H, W] that contains
|
| 23 |
+
BGR data (typically in uint8)
|
| 24 |
+
Returns:
|
| 25 |
+
images (torch.Tensor): tensor of size [N, 3, H1, W1] where
|
| 26 |
+
H1 and W1 are chosen to respect the specified min and max sizes
|
| 27 |
+
and preserve the original aspect ratio, the data channels
|
| 28 |
+
follow BGR order and the data type is `torch.float32`
|
| 29 |
+
"""
|
| 30 |
+
# resize with min size
|
| 31 |
+
images = images.float()
|
| 32 |
+
min_size = min(images.shape[-2:])
|
| 33 |
+
max_size = max(images.shape[-2:])
|
| 34 |
+
scale = min(self.min_size / min_size, self.max_size / max_size)
|
| 35 |
+
images = torch.nn.functional.interpolate(
|
| 36 |
+
images,
|
| 37 |
+
scale_factor=scale,
|
| 38 |
+
mode="bilinear",
|
| 39 |
+
align_corners=False,
|
| 40 |
+
)
|
| 41 |
+
return images
|
densepose/data/utils.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from typing import Dict, Optional
|
| 7 |
+
|
| 8 |
+
from detectron2.config import CfgNode
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def is_relative_local_path(path: str) -> bool:
|
| 12 |
+
path_str = os.fsdecode(path)
|
| 13 |
+
return ("://" not in path_str) and not os.path.isabs(path)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def maybe_prepend_base_path(base_path: Optional[str], path: str):
|
| 17 |
+
"""
|
| 18 |
+
Prepends the provided path with a base path prefix if:
|
| 19 |
+
1) base path is not None;
|
| 20 |
+
2) path is a local path
|
| 21 |
+
"""
|
| 22 |
+
if base_path is None:
|
| 23 |
+
return path
|
| 24 |
+
if is_relative_local_path(path):
|
| 25 |
+
return os.path.join(base_path, path)
|
| 26 |
+
return path
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_class_to_mesh_name_mapping(cfg: CfgNode) -> Dict[int, str]:
|
| 30 |
+
return {
|
| 31 |
+
int(class_id): mesh_name
|
| 32 |
+
for class_id, mesh_name in cfg.DATASETS.CLASS_TO_MESH_NAME_MAPPING.items()
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def get_category_to_class_mapping(dataset_cfg: CfgNode) -> Dict[str, int]:
|
| 37 |
+
return {
|
| 38 |
+
category: int(class_id)
|
| 39 |
+
for category, class_id in dataset_cfg.CATEGORY_TO_CLASS_MAPPING.items()
|
| 40 |
+
}
|
densepose/data/video/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from .frame_selector import (
|
| 6 |
+
FrameSelectionStrategy,
|
| 7 |
+
RandomKFramesSelector,
|
| 8 |
+
FirstKFramesSelector,
|
| 9 |
+
LastKFramesSelector,
|
| 10 |
+
FrameTsList,
|
| 11 |
+
FrameSelector,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
from .video_keyframe_dataset import (
|
| 15 |
+
VideoKeyframeDataset,
|
| 16 |
+
video_list_from_file,
|
| 17 |
+
list_keyframes,
|
| 18 |
+
read_keyframes,
|
| 19 |
+
)
|
densepose/data/video/frame_selector.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
import random
|
| 6 |
+
from collections.abc import Callable
|
| 7 |
+
from enum import Enum
|
| 8 |
+
from typing import Callable as TCallable
|
| 9 |
+
from typing import List
|
| 10 |
+
|
| 11 |
+
FrameTsList = List[int]
|
| 12 |
+
FrameSelector = TCallable[[FrameTsList], FrameTsList]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class FrameSelectionStrategy(Enum):
|
| 16 |
+
"""
|
| 17 |
+
Frame selection strategy used with videos:
|
| 18 |
+
- "random_k": select k random frames
|
| 19 |
+
- "first_k": select k first frames
|
| 20 |
+
- "last_k": select k last frames
|
| 21 |
+
- "all": select all frames
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
# fmt: off
|
| 25 |
+
RANDOM_K = "random_k"
|
| 26 |
+
FIRST_K = "first_k"
|
| 27 |
+
LAST_K = "last_k"
|
| 28 |
+
ALL = "all"
|
| 29 |
+
# fmt: on
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class RandomKFramesSelector(Callable): # pyre-ignore[39]
|
| 33 |
+
"""
|
| 34 |
+
Selector that retains at most `k` random frames
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def __init__(self, k: int):
|
| 38 |
+
self.k = k
|
| 39 |
+
|
| 40 |
+
def __call__(self, frame_tss: FrameTsList) -> FrameTsList:
|
| 41 |
+
"""
|
| 42 |
+
Select `k` random frames
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
frames_tss (List[int]): timestamps of input frames
|
| 46 |
+
Returns:
|
| 47 |
+
List[int]: timestamps of selected frames
|
| 48 |
+
"""
|
| 49 |
+
return random.sample(frame_tss, min(self.k, len(frame_tss)))
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class FirstKFramesSelector(Callable): # pyre-ignore[39]
|
| 53 |
+
"""
|
| 54 |
+
Selector that retains at most `k` first frames
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
def __init__(self, k: int):
|
| 58 |
+
self.k = k
|
| 59 |
+
|
| 60 |
+
def __call__(self, frame_tss: FrameTsList) -> FrameTsList:
|
| 61 |
+
"""
|
| 62 |
+
Select `k` first frames
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
frames_tss (List[int]): timestamps of input frames
|
| 66 |
+
Returns:
|
| 67 |
+
List[int]: timestamps of selected frames
|
| 68 |
+
"""
|
| 69 |
+
return frame_tss[: self.k]
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class LastKFramesSelector(Callable): # pyre-ignore[39]
|
| 73 |
+
"""
|
| 74 |
+
Selector that retains at most `k` last frames from video data
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
def __init__(self, k: int):
|
| 78 |
+
self.k = k
|
| 79 |
+
|
| 80 |
+
def __call__(self, frame_tss: FrameTsList) -> FrameTsList:
|
| 81 |
+
"""
|
| 82 |
+
Select `k` last frames
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
frames_tss (List[int]): timestamps of input frames
|
| 86 |
+
Returns:
|
| 87 |
+
List[int]: timestamps of selected frames
|
| 88 |
+
"""
|
| 89 |
+
return frame_tss[-self.k :]
|
densepose/data/video/video_keyframe_dataset.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
|
| 4 |
+
# pyre-unsafe
|
| 5 |
+
|
| 6 |
+
import csv
|
| 7 |
+
import logging
|
| 8 |
+
import numpy as np
|
| 9 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 10 |
+
import av
|
| 11 |
+
import torch
|
| 12 |
+
from torch.utils.data.dataset import Dataset
|
| 13 |
+
|
| 14 |
+
from detectron2.utils.file_io import PathManager
|
| 15 |
+
|
| 16 |
+
from ..utils import maybe_prepend_base_path
|
| 17 |
+
from .frame_selector import FrameSelector, FrameTsList
|
| 18 |
+
|
| 19 |
+
FrameList = List[av.frame.Frame] # pyre-ignore[16]
|
| 20 |
+
FrameTransform = Callable[[torch.Tensor], torch.Tensor]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def list_keyframes(video_fpath: str, video_stream_idx: int = 0) -> FrameTsList:
|
| 24 |
+
"""
|
| 25 |
+
Traverses all keyframes of a video file. Returns a list of keyframe
|
| 26 |
+
timestamps. Timestamps are counts in timebase units.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
video_fpath (str): Video file path
|
| 30 |
+
video_stream_idx (int): Video stream index (default: 0)
|
| 31 |
+
Returns:
|
| 32 |
+
List[int]: list of keyframe timestaps (timestamp is a count in timebase
|
| 33 |
+
units)
|
| 34 |
+
"""
|
| 35 |
+
try:
|
| 36 |
+
with PathManager.open(video_fpath, "rb") as io:
|
| 37 |
+
# pyre-fixme[16]: Module `av` has no attribute `open`.
|
| 38 |
+
container = av.open(io, mode="r")
|
| 39 |
+
stream = container.streams.video[video_stream_idx]
|
| 40 |
+
keyframes = []
|
| 41 |
+
pts = -1
|
| 42 |
+
# Note: even though we request forward seeks for keyframes, sometimes
|
| 43 |
+
# a keyframe in backwards direction is returned. We introduce tolerance
|
| 44 |
+
# as a max count of ignored backward seeks
|
| 45 |
+
tolerance_backward_seeks = 2
|
| 46 |
+
while True:
|
| 47 |
+
try:
|
| 48 |
+
container.seek(pts + 1, backward=False, any_frame=False, stream=stream)
|
| 49 |
+
except av.AVError as e:
|
| 50 |
+
# the exception occurs when the video length is exceeded,
|
| 51 |
+
# we then return whatever data we've already collected
|
| 52 |
+
logger = logging.getLogger(__name__)
|
| 53 |
+
logger.debug(
|
| 54 |
+
f"List keyframes: Error seeking video file {video_fpath}, "
|
| 55 |
+
f"video stream {video_stream_idx}, pts {pts + 1}, AV error: {e}"
|
| 56 |
+
)
|
| 57 |
+
return keyframes
|
| 58 |
+
except OSError as e:
|
| 59 |
+
logger = logging.getLogger(__name__)
|
| 60 |
+
logger.warning(
|
| 61 |
+
f"List keyframes: Error seeking video file {video_fpath}, "
|
| 62 |
+
f"video stream {video_stream_idx}, pts {pts + 1}, OS error: {e}"
|
| 63 |
+
)
|
| 64 |
+
return []
|
| 65 |
+
packet = next(container.demux(video=video_stream_idx))
|
| 66 |
+
if packet.pts is not None and packet.pts <= pts:
|
| 67 |
+
logger = logging.getLogger(__name__)
|
| 68 |
+
logger.warning(
|
| 69 |
+
f"Video file {video_fpath}, stream {video_stream_idx}: "
|
| 70 |
+
f"bad seek for packet {pts + 1} (got packet {packet.pts}), "
|
| 71 |
+
f"tolerance {tolerance_backward_seeks}."
|
| 72 |
+
)
|
| 73 |
+
tolerance_backward_seeks -= 1
|
| 74 |
+
if tolerance_backward_seeks == 0:
|
| 75 |
+
return []
|
| 76 |
+
pts += 1
|
| 77 |
+
continue
|
| 78 |
+
tolerance_backward_seeks = 2
|
| 79 |
+
pts = packet.pts
|
| 80 |
+
if pts is None:
|
| 81 |
+
return keyframes
|
| 82 |
+
if packet.is_keyframe:
|
| 83 |
+
keyframes.append(pts)
|
| 84 |
+
return keyframes
|
| 85 |
+
except OSError as e:
|
| 86 |
+
logger = logging.getLogger(__name__)
|
| 87 |
+
logger.warning(
|
| 88 |
+
f"List keyframes: Error opening video file container {video_fpath}, " f"OS error: {e}"
|
| 89 |
+
)
|
| 90 |
+
except RuntimeError as e:
|
| 91 |
+
logger = logging.getLogger(__name__)
|
| 92 |
+
logger.warning(
|
| 93 |
+
f"List keyframes: Error opening video file container {video_fpath}, "
|
| 94 |
+
f"Runtime error: {e}"
|
| 95 |
+
)
|
| 96 |
+
return []
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def read_keyframes(
|
| 100 |
+
video_fpath: str, keyframes: FrameTsList, video_stream_idx: int = 0
|
| 101 |
+
) -> FrameList: # pyre-ignore[11]
|
| 102 |
+
"""
|
| 103 |
+
Reads keyframe data from a video file.
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
video_fpath (str): Video file path
|
| 107 |
+
keyframes (List[int]): List of keyframe timestamps (as counts in
|
| 108 |
+
timebase units to be used in container seek operations)
|
| 109 |
+
video_stream_idx (int): Video stream index (default: 0)
|
| 110 |
+
Returns:
|
| 111 |
+
List[Frame]: list of frames that correspond to the specified timestamps
|
| 112 |
+
"""
|
| 113 |
+
try:
|
| 114 |
+
with PathManager.open(video_fpath, "rb") as io:
|
| 115 |
+
# pyre-fixme[16]: Module `av` has no attribute `open`.
|
| 116 |
+
container = av.open(io)
|
| 117 |
+
stream = container.streams.video[video_stream_idx]
|
| 118 |
+
frames = []
|
| 119 |
+
for pts in keyframes:
|
| 120 |
+
try:
|
| 121 |
+
container.seek(pts, any_frame=False, stream=stream)
|
| 122 |
+
frame = next(container.decode(video=0))
|
| 123 |
+
frames.append(frame)
|
| 124 |
+
except av.AVError as e:
|
| 125 |
+
logger = logging.getLogger(__name__)
|
| 126 |
+
logger.warning(
|
| 127 |
+
f"Read keyframes: Error seeking video file {video_fpath}, "
|
| 128 |
+
f"video stream {video_stream_idx}, pts {pts}, AV error: {e}"
|
| 129 |
+
)
|
| 130 |
+
container.close()
|
| 131 |
+
return frames
|
| 132 |
+
except OSError as e:
|
| 133 |
+
logger = logging.getLogger(__name__)
|
| 134 |
+
logger.warning(
|
| 135 |
+
f"Read keyframes: Error seeking video file {video_fpath}, "
|
| 136 |
+
f"video stream {video_stream_idx}, pts {pts}, OS error: {e}"
|
| 137 |
+
)
|
| 138 |
+
container.close()
|
| 139 |
+
return frames
|
| 140 |
+
except StopIteration:
|
| 141 |
+
logger = logging.getLogger(__name__)
|
| 142 |
+
logger.warning(
|
| 143 |
+
f"Read keyframes: Error decoding frame from {video_fpath}, "
|
| 144 |
+
f"video stream {video_stream_idx}, pts {pts}"
|
| 145 |
+
)
|
| 146 |
+
container.close()
|
| 147 |
+
return frames
|
| 148 |
+
|
| 149 |
+
container.close()
|
| 150 |
+
return frames
|
| 151 |
+
except OSError as e:
|
| 152 |
+
logger = logging.getLogger(__name__)
|
| 153 |
+
logger.warning(
|
| 154 |
+
f"Read keyframes: Error opening video file container {video_fpath}, OS error: {e}"
|
| 155 |
+
)
|
| 156 |
+
except RuntimeError as e:
|
| 157 |
+
logger = logging.getLogger(__name__)
|
| 158 |
+
logger.warning(
|
| 159 |
+
f"Read keyframes: Error opening video file container {video_fpath}, Runtime error: {e}"
|
| 160 |
+
)
|
| 161 |
+
return []
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def video_list_from_file(video_list_fpath: str, base_path: Optional[str] = None):
|
| 165 |
+
"""
|
| 166 |
+
Create a list of paths to video files from a text file.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
video_list_fpath (str): path to a plain text file with the list of videos
|
| 170 |
+
base_path (str): base path for entries from the video list (default: None)
|
| 171 |
+
"""
|
| 172 |
+
video_list = []
|
| 173 |
+
with PathManager.open(video_list_fpath, "r") as io:
|
| 174 |
+
for line in io:
|
| 175 |
+
video_list.append(maybe_prepend_base_path(base_path, str(line.strip())))
|
| 176 |
+
return video_list
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def read_keyframe_helper_data(fpath: str):
|
| 180 |
+
"""
|
| 181 |
+
Read keyframe data from a file in CSV format: the header should contain
|
| 182 |
+
"video_id" and "keyframes" fields. Value specifications are:
|
| 183 |
+
video_id: int
|
| 184 |
+
keyframes: list(int)
|
| 185 |
+
Example of contents:
|
| 186 |
+
video_id,keyframes
|
| 187 |
+
2,"[1,11,21,31,41,51,61,71,81]"
|
| 188 |
+
|
| 189 |
+
Args:
|
| 190 |
+
fpath (str): File containing keyframe data
|
| 191 |
+
|
| 192 |
+
Return:
|
| 193 |
+
video_id_to_keyframes (dict: int -> list(int)): for a given video ID it
|
| 194 |
+
contains a list of keyframes for that video
|
| 195 |
+
"""
|
| 196 |
+
video_id_to_keyframes = {}
|
| 197 |
+
try:
|
| 198 |
+
with PathManager.open(fpath, "r") as io:
|
| 199 |
+
csv_reader = csv.reader(io)
|
| 200 |
+
header = next(csv_reader)
|
| 201 |
+
video_id_idx = header.index("video_id")
|
| 202 |
+
keyframes_idx = header.index("keyframes")
|
| 203 |
+
for row in csv_reader:
|
| 204 |
+
video_id = int(row[video_id_idx])
|
| 205 |
+
assert (
|
| 206 |
+
video_id not in video_id_to_keyframes
|
| 207 |
+
), f"Duplicate keyframes entry for video {fpath}"
|
| 208 |
+
video_id_to_keyframes[video_id] = (
|
| 209 |
+
[int(v) for v in row[keyframes_idx][1:-1].split(",")]
|
| 210 |
+
if len(row[keyframes_idx]) > 2
|
| 211 |
+
else []
|
| 212 |
+
)
|
| 213 |
+
except Exception as e:
|
| 214 |
+
logger = logging.getLogger(__name__)
|
| 215 |
+
logger.warning(f"Error reading keyframe helper data from {fpath}: {e}")
|
| 216 |
+
return video_id_to_keyframes
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class VideoKeyframeDataset(Dataset):
|
| 220 |
+
"""
|
| 221 |
+
Dataset that provides keyframes for a set of videos.
|
| 222 |
+
"""
|
| 223 |
+
|
| 224 |
+
_EMPTY_FRAMES = torch.empty((0, 3, 1, 1))
|
| 225 |
+
|
| 226 |
+
def __init__(
|
| 227 |
+
self,
|
| 228 |
+
video_list: List[str],
|
| 229 |
+
category_list: Union[str, List[str], None] = None,
|
| 230 |
+
frame_selector: Optional[FrameSelector] = None,
|
| 231 |
+
transform: Optional[FrameTransform] = None,
|
| 232 |
+
keyframe_helper_fpath: Optional[str] = None,
|
| 233 |
+
):
|
| 234 |
+
"""
|
| 235 |
+
Dataset constructor
|
| 236 |
+
|
| 237 |
+
Args:
|
| 238 |
+
video_list (List[str]): list of paths to video files
|
| 239 |
+
category_list (Union[str, List[str], None]): list of animal categories for each
|
| 240 |
+
video file. If it is a string, or None, this applies to all videos
|
| 241 |
+
frame_selector (Callable: KeyFrameList -> KeyFrameList):
|
| 242 |
+
selects keyframes to process, keyframes are given by
|
| 243 |
+
packet timestamps in timebase counts. If None, all keyframes
|
| 244 |
+
are selected (default: None)
|
| 245 |
+
transform (Callable: torch.Tensor -> torch.Tensor):
|
| 246 |
+
transforms a batch of RGB images (tensors of size [B, 3, H, W]),
|
| 247 |
+
returns a tensor of the same size. If None, no transform is
|
| 248 |
+
applied (default: None)
|
| 249 |
+
|
| 250 |
+
"""
|
| 251 |
+
if type(category_list) is list:
|
| 252 |
+
self.category_list = category_list
|
| 253 |
+
else:
|
| 254 |
+
self.category_list = [category_list] * len(video_list)
|
| 255 |
+
assert len(video_list) == len(
|
| 256 |
+
self.category_list
|
| 257 |
+
), "length of video and category lists must be equal"
|
| 258 |
+
self.video_list = video_list
|
| 259 |
+
self.frame_selector = frame_selector
|
| 260 |
+
self.transform = transform
|
| 261 |
+
self.keyframe_helper_data = (
|
| 262 |
+
read_keyframe_helper_data(keyframe_helper_fpath)
|
| 263 |
+
if keyframe_helper_fpath is not None
|
| 264 |
+
else None
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
def __getitem__(self, idx: int) -> Dict[str, Any]:
|
| 268 |
+
"""
|
| 269 |
+
Gets selected keyframes from a given video
|
| 270 |
+
|
| 271 |
+
Args:
|
| 272 |
+
idx (int): video index in the video list file
|
| 273 |
+
Returns:
|
| 274 |
+
A dictionary containing two keys:
|
| 275 |
+
images (torch.Tensor): tensor of size [N, H, W, 3] or of size
|
| 276 |
+
defined by the transform that contains keyframes data
|
| 277 |
+
categories (List[str]): categories of the frames
|
| 278 |
+
"""
|
| 279 |
+
categories = [self.category_list[idx]]
|
| 280 |
+
fpath = self.video_list[idx]
|
| 281 |
+
keyframes = (
|
| 282 |
+
list_keyframes(fpath)
|
| 283 |
+
if self.keyframe_helper_data is None or idx not in self.keyframe_helper_data
|
| 284 |
+
else self.keyframe_helper_data[idx]
|
| 285 |
+
)
|
| 286 |
+
transform = self.transform
|
| 287 |
+
frame_selector = self.frame_selector
|
| 288 |
+
if not keyframes:
|
| 289 |
+
return {"images": self._EMPTY_FRAMES, "categories": []}
|
| 290 |
+
if frame_selector is not None:
|
| 291 |
+
keyframes = frame_selector(keyframes)
|
| 292 |
+
frames = read_keyframes(fpath, keyframes)
|
| 293 |
+
if not frames:
|
| 294 |
+
return {"images": self._EMPTY_FRAMES, "categories": []}
|
| 295 |
+
frames = np.stack([frame.to_rgb().to_ndarray() for frame in frames])
|
| 296 |
+
frames = torch.as_tensor(frames, device=torch.device("cpu"))
|
| 297 |
+
frames = frames[..., [2, 1, 0]] # RGB -> BGR
|
| 298 |
+
frames = frames.permute(0, 3, 1, 2).float() # NHWC -> NCHW
|
| 299 |
+
if transform is not None:
|
| 300 |
+
frames = transform(frames)
|
| 301 |
+
return {"images": frames, "categories": categories}
|
| 302 |
+
|
| 303 |
+
def __len__(self):
|
| 304 |
+
return len(self.video_list)
|
densepose/engine/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from .trainer import Trainer
|
densepose/engine/trainer.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
from collections import OrderedDict
|
| 8 |
+
from typing import List, Optional, Union
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
|
| 12 |
+
from detectron2.checkpoint import DetectionCheckpointer
|
| 13 |
+
from detectron2.config import CfgNode
|
| 14 |
+
from detectron2.engine import DefaultTrainer
|
| 15 |
+
from detectron2.evaluation import (
|
| 16 |
+
DatasetEvaluator,
|
| 17 |
+
DatasetEvaluators,
|
| 18 |
+
inference_on_dataset,
|
| 19 |
+
print_csv_format,
|
| 20 |
+
)
|
| 21 |
+
from detectron2.solver.build import get_default_optimizer_params, maybe_add_gradient_clipping
|
| 22 |
+
from detectron2.utils import comm
|
| 23 |
+
from detectron2.utils.events import EventWriter, get_event_storage
|
| 24 |
+
|
| 25 |
+
from densepose import DensePoseDatasetMapperTTA, DensePoseGeneralizedRCNNWithTTA, load_from_cfg
|
| 26 |
+
from densepose.data import (
|
| 27 |
+
DatasetMapper,
|
| 28 |
+
build_combined_loader,
|
| 29 |
+
build_detection_test_loader,
|
| 30 |
+
build_detection_train_loader,
|
| 31 |
+
build_inference_based_loaders,
|
| 32 |
+
has_inference_based_loaders,
|
| 33 |
+
)
|
| 34 |
+
from densepose.evaluation.d2_evaluator_adapter import Detectron2COCOEvaluatorAdapter
|
| 35 |
+
from densepose.evaluation.evaluator import DensePoseCOCOEvaluator, build_densepose_evaluator_storage
|
| 36 |
+
from densepose.modeling.cse import Embedder
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class SampleCountingLoader:
|
| 40 |
+
def __init__(self, loader):
|
| 41 |
+
self.loader = loader
|
| 42 |
+
|
| 43 |
+
def __iter__(self):
|
| 44 |
+
it = iter(self.loader)
|
| 45 |
+
storage = get_event_storage()
|
| 46 |
+
while True:
|
| 47 |
+
try:
|
| 48 |
+
batch = next(it)
|
| 49 |
+
num_inst_per_dataset = {}
|
| 50 |
+
for data in batch:
|
| 51 |
+
dataset_name = data["dataset"]
|
| 52 |
+
if dataset_name not in num_inst_per_dataset:
|
| 53 |
+
num_inst_per_dataset[dataset_name] = 0
|
| 54 |
+
num_inst = len(data["instances"])
|
| 55 |
+
num_inst_per_dataset[dataset_name] += num_inst
|
| 56 |
+
for dataset_name in num_inst_per_dataset:
|
| 57 |
+
storage.put_scalar(f"batch/{dataset_name}", num_inst_per_dataset[dataset_name])
|
| 58 |
+
yield batch
|
| 59 |
+
except StopIteration:
|
| 60 |
+
break
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class SampleCountMetricPrinter(EventWriter):
|
| 64 |
+
def __init__(self):
|
| 65 |
+
self.logger = logging.getLogger(__name__)
|
| 66 |
+
|
| 67 |
+
def write(self):
|
| 68 |
+
storage = get_event_storage()
|
| 69 |
+
batch_stats_strs = []
|
| 70 |
+
for key, buf in storage.histories().items():
|
| 71 |
+
if key.startswith("batch/"):
|
| 72 |
+
batch_stats_strs.append(f"{key} {buf.avg(20)}")
|
| 73 |
+
self.logger.info(", ".join(batch_stats_strs))
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class Trainer(DefaultTrainer):
|
| 77 |
+
@classmethod
|
| 78 |
+
def extract_embedder_from_model(cls, model: nn.Module) -> Optional[Embedder]:
|
| 79 |
+
if isinstance(model, nn.parallel.DistributedDataParallel):
|
| 80 |
+
model = model.module
|
| 81 |
+
if hasattr(model, "roi_heads") and hasattr(model.roi_heads, "embedder"):
|
| 82 |
+
return model.roi_heads.embedder
|
| 83 |
+
return None
|
| 84 |
+
|
| 85 |
+
# TODO: the only reason to copy the base class code here is to pass the embedder from
|
| 86 |
+
# the model to the evaluator; that should be refactored to avoid unnecessary copy-pasting
|
| 87 |
+
@classmethod
|
| 88 |
+
def test(
|
| 89 |
+
cls,
|
| 90 |
+
cfg: CfgNode,
|
| 91 |
+
model: nn.Module,
|
| 92 |
+
evaluators: Optional[Union[DatasetEvaluator, List[DatasetEvaluator]]] = None,
|
| 93 |
+
):
|
| 94 |
+
"""
|
| 95 |
+
Args:
|
| 96 |
+
cfg (CfgNode):
|
| 97 |
+
model (nn.Module):
|
| 98 |
+
evaluators (DatasetEvaluator, list[DatasetEvaluator] or None): if None, will call
|
| 99 |
+
:meth:`build_evaluator`. Otherwise, must have the same length as
|
| 100 |
+
``cfg.DATASETS.TEST``.
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
dict: a dict of result metrics
|
| 104 |
+
"""
|
| 105 |
+
logger = logging.getLogger(__name__)
|
| 106 |
+
if isinstance(evaluators, DatasetEvaluator):
|
| 107 |
+
evaluators = [evaluators]
|
| 108 |
+
if evaluators is not None:
|
| 109 |
+
assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
|
| 110 |
+
len(cfg.DATASETS.TEST), len(evaluators)
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
results = OrderedDict()
|
| 114 |
+
for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
|
| 115 |
+
data_loader = cls.build_test_loader(cfg, dataset_name)
|
| 116 |
+
# When evaluators are passed in as arguments,
|
| 117 |
+
# implicitly assume that evaluators can be created before data_loader.
|
| 118 |
+
if evaluators is not None:
|
| 119 |
+
evaluator = evaluators[idx]
|
| 120 |
+
else:
|
| 121 |
+
try:
|
| 122 |
+
embedder = cls.extract_embedder_from_model(model)
|
| 123 |
+
evaluator = cls.build_evaluator(cfg, dataset_name, embedder=embedder)
|
| 124 |
+
except NotImplementedError:
|
| 125 |
+
logger.warn(
|
| 126 |
+
"No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
|
| 127 |
+
"or implement its `build_evaluator` method."
|
| 128 |
+
)
|
| 129 |
+
results[dataset_name] = {}
|
| 130 |
+
continue
|
| 131 |
+
if cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE or comm.is_main_process():
|
| 132 |
+
results_i = inference_on_dataset(model, data_loader, evaluator)
|
| 133 |
+
else:
|
| 134 |
+
results_i = {}
|
| 135 |
+
results[dataset_name] = results_i
|
| 136 |
+
if comm.is_main_process():
|
| 137 |
+
assert isinstance(
|
| 138 |
+
results_i, dict
|
| 139 |
+
), "Evaluator must return a dict on the main process. Got {} instead.".format(
|
| 140 |
+
results_i
|
| 141 |
+
)
|
| 142 |
+
logger.info("Evaluation results for {} in csv format:".format(dataset_name))
|
| 143 |
+
print_csv_format(results_i)
|
| 144 |
+
|
| 145 |
+
if len(results) == 1:
|
| 146 |
+
results = list(results.values())[0]
|
| 147 |
+
return results
|
| 148 |
+
|
| 149 |
+
@classmethod
|
| 150 |
+
def build_evaluator(
|
| 151 |
+
cls,
|
| 152 |
+
cfg: CfgNode,
|
| 153 |
+
dataset_name: str,
|
| 154 |
+
output_folder: Optional[str] = None,
|
| 155 |
+
embedder: Optional[Embedder] = None,
|
| 156 |
+
) -> DatasetEvaluators:
|
| 157 |
+
if output_folder is None:
|
| 158 |
+
output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
|
| 159 |
+
evaluators = []
|
| 160 |
+
distributed = cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE
|
| 161 |
+
# Note: we currently use COCO evaluator for both COCO and LVIS datasets
|
| 162 |
+
# to have compatible metrics. LVIS bbox evaluator could also be used
|
| 163 |
+
# with an adapter to properly handle filtered / mapped categories
|
| 164 |
+
# evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
|
| 165 |
+
# if evaluator_type == "coco":
|
| 166 |
+
# evaluators.append(COCOEvaluator(dataset_name, output_dir=output_folder))
|
| 167 |
+
# elif evaluator_type == "lvis":
|
| 168 |
+
# evaluators.append(LVISEvaluator(dataset_name, output_dir=output_folder))
|
| 169 |
+
evaluators.append(
|
| 170 |
+
Detectron2COCOEvaluatorAdapter(
|
| 171 |
+
dataset_name, output_dir=output_folder, distributed=distributed
|
| 172 |
+
)
|
| 173 |
+
)
|
| 174 |
+
if cfg.MODEL.DENSEPOSE_ON:
|
| 175 |
+
storage = build_densepose_evaluator_storage(cfg, output_folder)
|
| 176 |
+
evaluators.append(
|
| 177 |
+
DensePoseCOCOEvaluator(
|
| 178 |
+
dataset_name,
|
| 179 |
+
distributed,
|
| 180 |
+
output_folder,
|
| 181 |
+
evaluator_type=cfg.DENSEPOSE_EVALUATION.TYPE,
|
| 182 |
+
min_iou_threshold=cfg.DENSEPOSE_EVALUATION.MIN_IOU_THRESHOLD,
|
| 183 |
+
storage=storage,
|
| 184 |
+
embedder=embedder,
|
| 185 |
+
should_evaluate_mesh_alignment=cfg.DENSEPOSE_EVALUATION.EVALUATE_MESH_ALIGNMENT,
|
| 186 |
+
mesh_alignment_mesh_names=cfg.DENSEPOSE_EVALUATION.MESH_ALIGNMENT_MESH_NAMES,
|
| 187 |
+
)
|
| 188 |
+
)
|
| 189 |
+
return DatasetEvaluators(evaluators)
|
| 190 |
+
|
| 191 |
+
@classmethod
|
| 192 |
+
def build_optimizer(cls, cfg: CfgNode, model: nn.Module):
|
| 193 |
+
params = get_default_optimizer_params(
|
| 194 |
+
model,
|
| 195 |
+
base_lr=cfg.SOLVER.BASE_LR,
|
| 196 |
+
weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
|
| 197 |
+
bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
|
| 198 |
+
weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
|
| 199 |
+
overrides={
|
| 200 |
+
"features": {
|
| 201 |
+
"lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.FEATURES_LR_FACTOR,
|
| 202 |
+
},
|
| 203 |
+
"embeddings": {
|
| 204 |
+
"lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_LR_FACTOR,
|
| 205 |
+
},
|
| 206 |
+
},
|
| 207 |
+
)
|
| 208 |
+
optimizer = torch.optim.SGD(
|
| 209 |
+
params,
|
| 210 |
+
cfg.SOLVER.BASE_LR,
|
| 211 |
+
momentum=cfg.SOLVER.MOMENTUM,
|
| 212 |
+
nesterov=cfg.SOLVER.NESTEROV,
|
| 213 |
+
weight_decay=cfg.SOLVER.WEIGHT_DECAY,
|
| 214 |
+
)
|
| 215 |
+
# pyre-fixme[6]: For 2nd param expected `Type[Optimizer]` but got `SGD`.
|
| 216 |
+
return maybe_add_gradient_clipping(cfg, optimizer)
|
| 217 |
+
|
| 218 |
+
@classmethod
|
| 219 |
+
def build_test_loader(cls, cfg: CfgNode, dataset_name):
|
| 220 |
+
return build_detection_test_loader(cfg, dataset_name, mapper=DatasetMapper(cfg, False))
|
| 221 |
+
|
| 222 |
+
@classmethod
|
| 223 |
+
def build_train_loader(cls, cfg: CfgNode):
|
| 224 |
+
data_loader = build_detection_train_loader(cfg, mapper=DatasetMapper(cfg, True))
|
| 225 |
+
if not has_inference_based_loaders(cfg):
|
| 226 |
+
return data_loader
|
| 227 |
+
model = cls.build_model(cfg)
|
| 228 |
+
model.to(cfg.BOOTSTRAP_MODEL.DEVICE)
|
| 229 |
+
DetectionCheckpointer(model).resume_or_load(cfg.BOOTSTRAP_MODEL.WEIGHTS, resume=False)
|
| 230 |
+
inference_based_loaders, ratios = build_inference_based_loaders(cfg, model)
|
| 231 |
+
loaders = [data_loader] + inference_based_loaders
|
| 232 |
+
ratios = [1.0] + ratios
|
| 233 |
+
combined_data_loader = build_combined_loader(cfg, loaders, ratios)
|
| 234 |
+
sample_counting_loader = SampleCountingLoader(combined_data_loader)
|
| 235 |
+
return sample_counting_loader
|
| 236 |
+
|
| 237 |
+
def build_writers(self):
|
| 238 |
+
writers = super().build_writers()
|
| 239 |
+
writers.append(SampleCountMetricPrinter())
|
| 240 |
+
return writers
|
| 241 |
+
|
| 242 |
+
@classmethod
|
| 243 |
+
def test_with_TTA(cls, cfg: CfgNode, model):
|
| 244 |
+
logger = logging.getLogger("detectron2.trainer")
|
| 245 |
+
# In the end of training, run an evaluation with TTA
|
| 246 |
+
# Only support some R-CNN models.
|
| 247 |
+
logger.info("Running inference with test-time augmentation ...")
|
| 248 |
+
transform_data = load_from_cfg(cfg)
|
| 249 |
+
model = DensePoseGeneralizedRCNNWithTTA(
|
| 250 |
+
cfg, model, transform_data, DensePoseDatasetMapperTTA(cfg)
|
| 251 |
+
)
|
| 252 |
+
evaluators = [
|
| 253 |
+
cls.build_evaluator(
|
| 254 |
+
cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
|
| 255 |
+
)
|
| 256 |
+
for name in cfg.DATASETS.TEST
|
| 257 |
+
]
|
| 258 |
+
res = cls.test(cfg, model, evaluators) # pyre-ignore[6]
|
| 259 |
+
res = OrderedDict({k + "_TTA": v for k, v in res.items()})
|
| 260 |
+
return res
|
densepose/evaluation/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from .evaluator import DensePoseCOCOEvaluator
|
densepose/evaluation/d2_evaluator_adapter.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from detectron2.data.catalog import Metadata
|
| 6 |
+
from detectron2.evaluation import COCOEvaluator
|
| 7 |
+
|
| 8 |
+
from densepose.data.datasets.coco import (
|
| 9 |
+
get_contiguous_id_to_category_id_map,
|
| 10 |
+
maybe_filter_categories_cocoapi,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _maybe_add_iscrowd_annotations(cocoapi) -> None:
|
| 15 |
+
for ann in cocoapi.dataset["annotations"]:
|
| 16 |
+
if "iscrowd" not in ann:
|
| 17 |
+
ann["iscrowd"] = 0
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class Detectron2COCOEvaluatorAdapter(COCOEvaluator):
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
dataset_name,
|
| 24 |
+
output_dir=None,
|
| 25 |
+
distributed=True,
|
| 26 |
+
):
|
| 27 |
+
super().__init__(dataset_name, output_dir=output_dir, distributed=distributed)
|
| 28 |
+
maybe_filter_categories_cocoapi(dataset_name, self._coco_api)
|
| 29 |
+
_maybe_add_iscrowd_annotations(self._coco_api)
|
| 30 |
+
# substitute category metadata to account for categories
|
| 31 |
+
# that are mapped to the same contiguous id
|
| 32 |
+
if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
|
| 33 |
+
self._maybe_substitute_metadata()
|
| 34 |
+
|
| 35 |
+
def _maybe_substitute_metadata(self):
|
| 36 |
+
cont_id_2_cat_id = get_contiguous_id_to_category_id_map(self._metadata)
|
| 37 |
+
cat_id_2_cont_id = self._metadata.thing_dataset_id_to_contiguous_id
|
| 38 |
+
if len(cont_id_2_cat_id) == len(cat_id_2_cont_id):
|
| 39 |
+
return
|
| 40 |
+
|
| 41 |
+
cat_id_2_cont_id_injective = {}
|
| 42 |
+
for cat_id, cont_id in cat_id_2_cont_id.items():
|
| 43 |
+
if (cont_id in cont_id_2_cat_id) and (cont_id_2_cat_id[cont_id] == cat_id):
|
| 44 |
+
cat_id_2_cont_id_injective[cat_id] = cont_id
|
| 45 |
+
|
| 46 |
+
metadata_new = Metadata(name=self._metadata.name)
|
| 47 |
+
for key, value in self._metadata.__dict__.items():
|
| 48 |
+
if key == "thing_dataset_id_to_contiguous_id":
|
| 49 |
+
setattr(metadata_new, key, cat_id_2_cont_id_injective)
|
| 50 |
+
else:
|
| 51 |
+
setattr(metadata_new, key, value)
|
| 52 |
+
self._metadata = metadata_new
|
densepose/evaluation/densepose_coco_evaluation.py
ADDED
|
@@ -0,0 +1,1305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# This is a modified version of cocoeval.py where we also have the densepose evaluation.
|
| 7 |
+
|
| 8 |
+
# pyre-unsafe
|
| 9 |
+
|
| 10 |
+
__author__ = "tsungyi"
|
| 11 |
+
|
| 12 |
+
import copy
|
| 13 |
+
import datetime
|
| 14 |
+
import logging
|
| 15 |
+
import numpy as np
|
| 16 |
+
import pickle
|
| 17 |
+
import time
|
| 18 |
+
from collections import defaultdict
|
| 19 |
+
from enum import Enum
|
| 20 |
+
from typing import Any, Dict, Tuple
|
| 21 |
+
import scipy.spatial.distance as ssd
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
from pycocotools import mask as maskUtils
|
| 25 |
+
from scipy.io import loadmat
|
| 26 |
+
from scipy.ndimage import zoom as spzoom
|
| 27 |
+
|
| 28 |
+
from detectron2.utils.file_io import PathManager
|
| 29 |
+
|
| 30 |
+
from densepose.converters.chart_output_to_chart_result import resample_uv_tensors_to_bbox
|
| 31 |
+
from densepose.converters.segm_to_mask import (
|
| 32 |
+
resample_coarse_segm_tensor_to_bbox,
|
| 33 |
+
resample_fine_and_coarse_segm_tensors_to_bbox,
|
| 34 |
+
)
|
| 35 |
+
from densepose.modeling.cse.utils import squared_euclidean_distance_matrix
|
| 36 |
+
from densepose.structures import DensePoseDataRelative
|
| 37 |
+
from densepose.structures.mesh import create_mesh
|
| 38 |
+
|
| 39 |
+
logger = logging.getLogger(__name__)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class DensePoseEvalMode(str, Enum):
|
| 43 |
+
# use both masks and geodesic distances (GPS * IOU) to compute scores
|
| 44 |
+
GPSM = "gpsm"
|
| 45 |
+
# use only geodesic distances (GPS) to compute scores
|
| 46 |
+
GPS = "gps"
|
| 47 |
+
# use only masks (IOU) to compute scores
|
| 48 |
+
IOU = "iou"
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class DensePoseDataMode(str, Enum):
|
| 52 |
+
# use estimated IUV data (default mode)
|
| 53 |
+
IUV_DT = "iuvdt"
|
| 54 |
+
# use ground truth IUV data
|
| 55 |
+
IUV_GT = "iuvgt"
|
| 56 |
+
# use ground truth labels I and set UV to 0
|
| 57 |
+
I_GT_UV_0 = "igtuv0"
|
| 58 |
+
# use ground truth labels I and estimated UV coordinates
|
| 59 |
+
I_GT_UV_DT = "igtuvdt"
|
| 60 |
+
# use estimated labels I and set UV to 0
|
| 61 |
+
I_DT_UV_0 = "idtuv0"
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class DensePoseCocoEval:
|
| 65 |
+
# Interface for evaluating detection on the Microsoft COCO dataset.
|
| 66 |
+
#
|
| 67 |
+
# The usage for CocoEval is as follows:
|
| 68 |
+
# cocoGt=..., cocoDt=... # load dataset and results
|
| 69 |
+
# E = CocoEval(cocoGt,cocoDt); # initialize CocoEval object
|
| 70 |
+
# E.params.recThrs = ...; # set parameters as desired
|
| 71 |
+
# E.evaluate(); # run per image evaluation
|
| 72 |
+
# E.accumulate(); # accumulate per image results
|
| 73 |
+
# E.summarize(); # display summary metrics of results
|
| 74 |
+
# For example usage see evalDemo.m and http://mscoco.org/.
|
| 75 |
+
#
|
| 76 |
+
# The evaluation parameters are as follows (defaults in brackets):
|
| 77 |
+
# imgIds - [all] N img ids to use for evaluation
|
| 78 |
+
# catIds - [all] K cat ids to use for evaluation
|
| 79 |
+
# iouThrs - [.5:.05:.95] T=10 IoU thresholds for evaluation
|
| 80 |
+
# recThrs - [0:.01:1] R=101 recall thresholds for evaluation
|
| 81 |
+
# areaRng - [...] A=4 object area ranges for evaluation
|
| 82 |
+
# maxDets - [1 10 100] M=3 thresholds on max detections per image
|
| 83 |
+
# iouType - ['segm'] set iouType to 'segm', 'bbox', 'keypoints' or 'densepose'
|
| 84 |
+
# iouType replaced the now DEPRECATED useSegm parameter.
|
| 85 |
+
# useCats - [1] if true use category labels for evaluation
|
| 86 |
+
# Note: if useCats=0 category labels are ignored as in proposal scoring.
|
| 87 |
+
# Note: multiple areaRngs [Ax2] and maxDets [Mx1] can be specified.
|
| 88 |
+
#
|
| 89 |
+
# evaluate(): evaluates detections on every image and every category and
|
| 90 |
+
# concats the results into the "evalImgs" with fields:
|
| 91 |
+
# dtIds - [1xD] id for each of the D detections (dt)
|
| 92 |
+
# gtIds - [1xG] id for each of the G ground truths (gt)
|
| 93 |
+
# dtMatches - [TxD] matching gt id at each IoU or 0
|
| 94 |
+
# gtMatches - [TxG] matching dt id at each IoU or 0
|
| 95 |
+
# dtScores - [1xD] confidence of each dt
|
| 96 |
+
# gtIgnore - [1xG] ignore flag for each gt
|
| 97 |
+
# dtIgnore - [TxD] ignore flag for each dt at each IoU
|
| 98 |
+
#
|
| 99 |
+
# accumulate(): accumulates the per-image, per-category evaluation
|
| 100 |
+
# results in "evalImgs" into the dictionary "eval" with fields:
|
| 101 |
+
# params - parameters used for evaluation
|
| 102 |
+
# date - date evaluation was performed
|
| 103 |
+
# counts - [T,R,K,A,M] parameter dimensions (see above)
|
| 104 |
+
# precision - [TxRxKxAxM] precision for every evaluation setting
|
| 105 |
+
# recall - [TxKxAxM] max recall for every evaluation setting
|
| 106 |
+
# Note: precision and recall==-1 for settings with no gt objects.
|
| 107 |
+
#
|
| 108 |
+
# See also coco, mask, pycocoDemo, pycocoEvalDemo
|
| 109 |
+
#
|
| 110 |
+
# Microsoft COCO Toolbox. version 2.0
|
| 111 |
+
# Data, paper, and tutorials available at: http://mscoco.org/
|
| 112 |
+
# Code written by Piotr Dollar and Tsung-Yi Lin, 2015.
|
| 113 |
+
# Licensed under the Simplified BSD License [see coco/license.txt]
|
| 114 |
+
def __init__(
|
| 115 |
+
self,
|
| 116 |
+
cocoGt=None,
|
| 117 |
+
cocoDt=None,
|
| 118 |
+
iouType: str = "densepose",
|
| 119 |
+
multi_storage=None,
|
| 120 |
+
embedder=None,
|
| 121 |
+
dpEvalMode: DensePoseEvalMode = DensePoseEvalMode.GPS,
|
| 122 |
+
dpDataMode: DensePoseDataMode = DensePoseDataMode.IUV_DT,
|
| 123 |
+
):
|
| 124 |
+
"""
|
| 125 |
+
Initialize CocoEval using coco APIs for gt and dt
|
| 126 |
+
:param cocoGt: coco object with ground truth annotations
|
| 127 |
+
:param cocoDt: coco object with detection results
|
| 128 |
+
:return: None
|
| 129 |
+
"""
|
| 130 |
+
self.cocoGt = cocoGt # ground truth COCO API
|
| 131 |
+
self.cocoDt = cocoDt # detections COCO API
|
| 132 |
+
self.multi_storage = multi_storage
|
| 133 |
+
self.embedder = embedder
|
| 134 |
+
self._dpEvalMode = dpEvalMode
|
| 135 |
+
self._dpDataMode = dpDataMode
|
| 136 |
+
self.evalImgs = defaultdict(list) # per-image per-category eval results [KxAxI]
|
| 137 |
+
self.eval = {} # accumulated evaluation results
|
| 138 |
+
self._gts = defaultdict(list) # gt for evaluation
|
| 139 |
+
self._dts = defaultdict(list) # dt for evaluation
|
| 140 |
+
self.params = Params(iouType=iouType) # parameters
|
| 141 |
+
self._paramsEval = {} # parameters for evaluation
|
| 142 |
+
self.stats = [] # result summarization
|
| 143 |
+
self.ious = {} # ious between all gts and dts
|
| 144 |
+
if cocoGt is not None:
|
| 145 |
+
self.params.imgIds = sorted(cocoGt.getImgIds())
|
| 146 |
+
self.params.catIds = sorted(cocoGt.getCatIds())
|
| 147 |
+
self.ignoreThrBB = 0.7
|
| 148 |
+
self.ignoreThrUV = 0.9
|
| 149 |
+
|
| 150 |
+
def _loadGEval(self):
|
| 151 |
+
smpl_subdiv_fpath = PathManager.get_local_path(
|
| 152 |
+
"https://dl.fbaipublicfiles.com/densepose/data/SMPL_subdiv.mat"
|
| 153 |
+
)
|
| 154 |
+
pdist_transform_fpath = PathManager.get_local_path(
|
| 155 |
+
"https://dl.fbaipublicfiles.com/densepose/data/SMPL_SUBDIV_TRANSFORM.mat"
|
| 156 |
+
)
|
| 157 |
+
pdist_matrix_fpath = PathManager.get_local_path(
|
| 158 |
+
"https://dl.fbaipublicfiles.com/densepose/data/Pdist_matrix.pkl", timeout_sec=120
|
| 159 |
+
)
|
| 160 |
+
SMPL_subdiv = loadmat(smpl_subdiv_fpath)
|
| 161 |
+
self.PDIST_transform = loadmat(pdist_transform_fpath)
|
| 162 |
+
self.PDIST_transform = self.PDIST_transform["index"].squeeze()
|
| 163 |
+
UV = np.array([SMPL_subdiv["U_subdiv"], SMPL_subdiv["V_subdiv"]]).squeeze()
|
| 164 |
+
ClosestVertInds = np.arange(UV.shape[1]) + 1
|
| 165 |
+
self.Part_UVs = []
|
| 166 |
+
self.Part_ClosestVertInds = []
|
| 167 |
+
for i in np.arange(24):
|
| 168 |
+
self.Part_UVs.append(UV[:, SMPL_subdiv["Part_ID_subdiv"].squeeze() == (i + 1)])
|
| 169 |
+
self.Part_ClosestVertInds.append(
|
| 170 |
+
ClosestVertInds[SMPL_subdiv["Part_ID_subdiv"].squeeze() == (i + 1)]
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
with open(pdist_matrix_fpath, "rb") as hFile:
|
| 174 |
+
arrays = pickle.load(hFile, encoding="latin1")
|
| 175 |
+
self.Pdist_matrix = arrays["Pdist_matrix"]
|
| 176 |
+
self.Part_ids = np.array(SMPL_subdiv["Part_ID_subdiv"].squeeze())
|
| 177 |
+
# Mean geodesic distances for parts.
|
| 178 |
+
self.Mean_Distances = np.array([0, 0.351, 0.107, 0.126, 0.237, 0.173, 0.142, 0.128, 0.150])
|
| 179 |
+
# Coarse Part labels.
|
| 180 |
+
self.CoarseParts = np.array(
|
| 181 |
+
[0, 1, 1, 2, 2, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8]
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
def _prepare(self):
|
| 185 |
+
"""
|
| 186 |
+
Prepare ._gts and ._dts for evaluation based on params
|
| 187 |
+
:return: None
|
| 188 |
+
"""
|
| 189 |
+
|
| 190 |
+
def _toMask(anns, coco):
|
| 191 |
+
# modify ann['segmentation'] by reference
|
| 192 |
+
for ann in anns:
|
| 193 |
+
# safeguard for invalid segmentation annotation;
|
| 194 |
+
# annotations containing empty lists exist in the posetrack
|
| 195 |
+
# dataset. This is not a correct segmentation annotation
|
| 196 |
+
# in terms of COCO format; we need to deal with it somehow
|
| 197 |
+
segm = ann["segmentation"]
|
| 198 |
+
if type(segm) is list and len(segm) == 0:
|
| 199 |
+
ann["segmentation"] = None
|
| 200 |
+
continue
|
| 201 |
+
rle = coco.annToRLE(ann)
|
| 202 |
+
ann["segmentation"] = rle
|
| 203 |
+
|
| 204 |
+
def _getIgnoreRegion(iid, coco):
|
| 205 |
+
img = coco.imgs[iid]
|
| 206 |
+
|
| 207 |
+
if "ignore_regions_x" not in img.keys():
|
| 208 |
+
return None
|
| 209 |
+
|
| 210 |
+
if len(img["ignore_regions_x"]) == 0:
|
| 211 |
+
return None
|
| 212 |
+
|
| 213 |
+
rgns_merged = [
|
| 214 |
+
[v for xy in zip(region_x, region_y) for v in xy]
|
| 215 |
+
for region_x, region_y in zip(img["ignore_regions_x"], img["ignore_regions_y"])
|
| 216 |
+
]
|
| 217 |
+
rles = maskUtils.frPyObjects(rgns_merged, img["height"], img["width"])
|
| 218 |
+
rle = maskUtils.merge(rles)
|
| 219 |
+
return maskUtils.decode(rle)
|
| 220 |
+
|
| 221 |
+
def _checkIgnore(dt, iregion):
|
| 222 |
+
if iregion is None:
|
| 223 |
+
return True
|
| 224 |
+
|
| 225 |
+
bb = np.array(dt["bbox"]).astype(int)
|
| 226 |
+
x1, y1, x2, y2 = bb[0], bb[1], bb[0] + bb[2], bb[1] + bb[3]
|
| 227 |
+
x2 = min([x2, iregion.shape[1]])
|
| 228 |
+
y2 = min([y2, iregion.shape[0]])
|
| 229 |
+
|
| 230 |
+
if bb[2] * bb[3] == 0:
|
| 231 |
+
return False
|
| 232 |
+
|
| 233 |
+
crop_iregion = iregion[y1:y2, x1:x2]
|
| 234 |
+
|
| 235 |
+
if crop_iregion.sum() == 0:
|
| 236 |
+
return True
|
| 237 |
+
|
| 238 |
+
if "densepose" not in dt.keys(): # filtering boxes
|
| 239 |
+
return crop_iregion.sum() / bb[2] / bb[3] < self.ignoreThrBB
|
| 240 |
+
|
| 241 |
+
# filtering UVs
|
| 242 |
+
ignoremask = np.require(crop_iregion, requirements=["F"])
|
| 243 |
+
mask = self._extract_mask(dt)
|
| 244 |
+
uvmask = np.require(np.asarray(mask > 0), dtype=np.uint8, requirements=["F"])
|
| 245 |
+
uvmask_ = maskUtils.encode(uvmask)
|
| 246 |
+
ignoremask_ = maskUtils.encode(ignoremask)
|
| 247 |
+
uviou = maskUtils.iou([uvmask_], [ignoremask_], [1])[0]
|
| 248 |
+
return uviou < self.ignoreThrUV
|
| 249 |
+
|
| 250 |
+
p = self.params
|
| 251 |
+
|
| 252 |
+
if p.useCats:
|
| 253 |
+
gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds))
|
| 254 |
+
dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds))
|
| 255 |
+
else:
|
| 256 |
+
gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds))
|
| 257 |
+
dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds))
|
| 258 |
+
|
| 259 |
+
imns = self.cocoGt.loadImgs(p.imgIds)
|
| 260 |
+
self.size_mapping = {}
|
| 261 |
+
for im in imns:
|
| 262 |
+
self.size_mapping[im["id"]] = [im["height"], im["width"]]
|
| 263 |
+
|
| 264 |
+
# if iouType == 'uv', add point gt annotations
|
| 265 |
+
if p.iouType == "densepose":
|
| 266 |
+
self._loadGEval()
|
| 267 |
+
|
| 268 |
+
# convert ground truth to mask if iouType == 'segm'
|
| 269 |
+
if p.iouType == "segm":
|
| 270 |
+
_toMask(gts, self.cocoGt)
|
| 271 |
+
_toMask(dts, self.cocoDt)
|
| 272 |
+
|
| 273 |
+
# set ignore flag
|
| 274 |
+
for gt in gts:
|
| 275 |
+
gt["ignore"] = gt["ignore"] if "ignore" in gt else 0
|
| 276 |
+
gt["ignore"] = "iscrowd" in gt and gt["iscrowd"]
|
| 277 |
+
if p.iouType == "keypoints":
|
| 278 |
+
gt["ignore"] = (gt["num_keypoints"] == 0) or gt["ignore"]
|
| 279 |
+
if p.iouType == "densepose":
|
| 280 |
+
gt["ignore"] = ("dp_x" in gt) == 0
|
| 281 |
+
if p.iouType == "segm":
|
| 282 |
+
gt["ignore"] = gt["segmentation"] is None
|
| 283 |
+
|
| 284 |
+
self._gts = defaultdict(list) # gt for evaluation
|
| 285 |
+
self._dts = defaultdict(list) # dt for evaluation
|
| 286 |
+
self._igrgns = defaultdict(list)
|
| 287 |
+
|
| 288 |
+
for gt in gts:
|
| 289 |
+
iid = gt["image_id"]
|
| 290 |
+
if iid not in self._igrgns.keys():
|
| 291 |
+
self._igrgns[iid] = _getIgnoreRegion(iid, self.cocoGt)
|
| 292 |
+
if _checkIgnore(gt, self._igrgns[iid]):
|
| 293 |
+
self._gts[iid, gt["category_id"]].append(gt)
|
| 294 |
+
for dt in dts:
|
| 295 |
+
iid = dt["image_id"]
|
| 296 |
+
if (iid not in self._igrgns) or _checkIgnore(dt, self._igrgns[iid]):
|
| 297 |
+
self._dts[iid, dt["category_id"]].append(dt)
|
| 298 |
+
|
| 299 |
+
self.evalImgs = defaultdict(list) # per-image per-category evaluation results
|
| 300 |
+
self.eval = {} # accumulated evaluation results
|
| 301 |
+
|
| 302 |
+
def evaluate(self):
|
| 303 |
+
"""
|
| 304 |
+
Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
|
| 305 |
+
:return: None
|
| 306 |
+
"""
|
| 307 |
+
tic = time.time()
|
| 308 |
+
logger.info("Running per image DensePose evaluation... {}".format(self.params.iouType))
|
| 309 |
+
p = self.params
|
| 310 |
+
# add backward compatibility if useSegm is specified in params
|
| 311 |
+
if p.useSegm is not None:
|
| 312 |
+
p.iouType = "segm" if p.useSegm == 1 else "bbox"
|
| 313 |
+
logger.info("useSegm (deprecated) is not None. Running DensePose evaluation")
|
| 314 |
+
p.imgIds = list(np.unique(p.imgIds))
|
| 315 |
+
if p.useCats:
|
| 316 |
+
p.catIds = list(np.unique(p.catIds))
|
| 317 |
+
p.maxDets = sorted(p.maxDets)
|
| 318 |
+
self.params = p
|
| 319 |
+
|
| 320 |
+
self._prepare()
|
| 321 |
+
# loop through images, area range, max detection number
|
| 322 |
+
catIds = p.catIds if p.useCats else [-1]
|
| 323 |
+
|
| 324 |
+
if p.iouType in ["segm", "bbox"]:
|
| 325 |
+
computeIoU = self.computeIoU
|
| 326 |
+
elif p.iouType == "keypoints":
|
| 327 |
+
computeIoU = self.computeOks
|
| 328 |
+
elif p.iouType == "densepose":
|
| 329 |
+
computeIoU = self.computeOgps
|
| 330 |
+
if self._dpEvalMode in {DensePoseEvalMode.GPSM, DensePoseEvalMode.IOU}:
|
| 331 |
+
self.real_ious = {
|
| 332 |
+
(imgId, catId): self.computeDPIoU(imgId, catId)
|
| 333 |
+
for imgId in p.imgIds
|
| 334 |
+
for catId in catIds
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
self.ious = {
|
| 338 |
+
(imgId, catId): computeIoU(imgId, catId) for imgId in p.imgIds for catId in catIds
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
evaluateImg = self.evaluateImg
|
| 342 |
+
maxDet = p.maxDets[-1]
|
| 343 |
+
self.evalImgs = [
|
| 344 |
+
evaluateImg(imgId, catId, areaRng, maxDet)
|
| 345 |
+
for catId in catIds
|
| 346 |
+
for areaRng in p.areaRng
|
| 347 |
+
for imgId in p.imgIds
|
| 348 |
+
]
|
| 349 |
+
self._paramsEval = copy.deepcopy(self.params)
|
| 350 |
+
toc = time.time()
|
| 351 |
+
logger.info("DensePose evaluation DONE (t={:0.2f}s).".format(toc - tic))
|
| 352 |
+
|
| 353 |
+
def getDensePoseMask(self, polys):
|
| 354 |
+
maskGen = np.zeros([256, 256])
|
| 355 |
+
stop = min(len(polys) + 1, 15)
|
| 356 |
+
for i in range(1, stop):
|
| 357 |
+
if polys[i - 1]:
|
| 358 |
+
currentMask = maskUtils.decode(polys[i - 1])
|
| 359 |
+
maskGen[currentMask > 0] = i
|
| 360 |
+
return maskGen
|
| 361 |
+
|
| 362 |
+
def _generate_rlemask_on_image(self, mask, imgId, data):
|
| 363 |
+
bbox_xywh = np.array(data["bbox"])
|
| 364 |
+
x, y, w, h = bbox_xywh
|
| 365 |
+
im_h, im_w = self.size_mapping[imgId]
|
| 366 |
+
im_mask = np.zeros((im_h, im_w), dtype=np.uint8)
|
| 367 |
+
if mask is not None:
|
| 368 |
+
x0 = max(int(x), 0)
|
| 369 |
+
x1 = min(int(x + w), im_w, int(x) + mask.shape[1])
|
| 370 |
+
y0 = max(int(y), 0)
|
| 371 |
+
y1 = min(int(y + h), im_h, int(y) + mask.shape[0])
|
| 372 |
+
y = int(y)
|
| 373 |
+
x = int(x)
|
| 374 |
+
im_mask[y0:y1, x0:x1] = mask[y0 - y : y1 - y, x0 - x : x1 - x]
|
| 375 |
+
im_mask = np.require(np.asarray(im_mask > 0), dtype=np.uint8, requirements=["F"])
|
| 376 |
+
rle_mask = maskUtils.encode(np.array(im_mask[:, :, np.newaxis], order="F"))[0]
|
| 377 |
+
return rle_mask
|
| 378 |
+
|
| 379 |
+
def computeDPIoU(self, imgId, catId):
|
| 380 |
+
p = self.params
|
| 381 |
+
if p.useCats:
|
| 382 |
+
gt = self._gts[imgId, catId]
|
| 383 |
+
dt = self._dts[imgId, catId]
|
| 384 |
+
else:
|
| 385 |
+
gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
|
| 386 |
+
dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
|
| 387 |
+
if len(gt) == 0 and len(dt) == 0:
|
| 388 |
+
return []
|
| 389 |
+
inds = np.argsort([-d["score"] for d in dt], kind="mergesort")
|
| 390 |
+
dt = [dt[i] for i in inds]
|
| 391 |
+
if len(dt) > p.maxDets[-1]:
|
| 392 |
+
dt = dt[0 : p.maxDets[-1]]
|
| 393 |
+
|
| 394 |
+
gtmasks = []
|
| 395 |
+
for g in gt:
|
| 396 |
+
if DensePoseDataRelative.S_KEY in g:
|
| 397 |
+
# convert DensePose mask to a binary mask
|
| 398 |
+
mask = np.minimum(self.getDensePoseMask(g[DensePoseDataRelative.S_KEY]), 1.0)
|
| 399 |
+
_, _, w, h = g["bbox"]
|
| 400 |
+
scale_x = float(max(w, 1)) / mask.shape[1]
|
| 401 |
+
scale_y = float(max(h, 1)) / mask.shape[0]
|
| 402 |
+
mask = spzoom(mask, (scale_y, scale_x), order=1, prefilter=False)
|
| 403 |
+
mask = np.array(mask > 0.5, dtype=np.uint8)
|
| 404 |
+
rle_mask = self._generate_rlemask_on_image(mask, imgId, g)
|
| 405 |
+
elif "segmentation" in g:
|
| 406 |
+
segmentation = g["segmentation"]
|
| 407 |
+
if isinstance(segmentation, list) and segmentation:
|
| 408 |
+
# polygons
|
| 409 |
+
im_h, im_w = self.size_mapping[imgId]
|
| 410 |
+
rles = maskUtils.frPyObjects(segmentation, im_h, im_w)
|
| 411 |
+
rle_mask = maskUtils.merge(rles)
|
| 412 |
+
elif isinstance(segmentation, dict):
|
| 413 |
+
if isinstance(segmentation["counts"], list):
|
| 414 |
+
# uncompressed RLE
|
| 415 |
+
im_h, im_w = self.size_mapping[imgId]
|
| 416 |
+
rle_mask = maskUtils.frPyObjects(segmentation, im_h, im_w)
|
| 417 |
+
else:
|
| 418 |
+
# compressed RLE
|
| 419 |
+
rle_mask = segmentation
|
| 420 |
+
else:
|
| 421 |
+
rle_mask = self._generate_rlemask_on_image(None, imgId, g)
|
| 422 |
+
else:
|
| 423 |
+
rle_mask = self._generate_rlemask_on_image(None, imgId, g)
|
| 424 |
+
gtmasks.append(rle_mask)
|
| 425 |
+
|
| 426 |
+
dtmasks = []
|
| 427 |
+
for d in dt:
|
| 428 |
+
mask = self._extract_mask(d)
|
| 429 |
+
mask = np.require(np.asarray(mask > 0), dtype=np.uint8, requirements=["F"])
|
| 430 |
+
rle_mask = self._generate_rlemask_on_image(mask, imgId, d)
|
| 431 |
+
dtmasks.append(rle_mask)
|
| 432 |
+
|
| 433 |
+
# compute iou between each dt and gt region
|
| 434 |
+
iscrowd = [int(o.get("iscrowd", 0)) for o in gt]
|
| 435 |
+
iousDP = maskUtils.iou(dtmasks, gtmasks, iscrowd)
|
| 436 |
+
return iousDP
|
| 437 |
+
|
| 438 |
+
def computeIoU(self, imgId, catId):
|
| 439 |
+
p = self.params
|
| 440 |
+
if p.useCats:
|
| 441 |
+
gt = self._gts[imgId, catId]
|
| 442 |
+
dt = self._dts[imgId, catId]
|
| 443 |
+
else:
|
| 444 |
+
gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
|
| 445 |
+
dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
|
| 446 |
+
if len(gt) == 0 and len(dt) == 0:
|
| 447 |
+
return []
|
| 448 |
+
inds = np.argsort([-d["score"] for d in dt], kind="mergesort")
|
| 449 |
+
dt = [dt[i] for i in inds]
|
| 450 |
+
if len(dt) > p.maxDets[-1]:
|
| 451 |
+
dt = dt[0 : p.maxDets[-1]]
|
| 452 |
+
|
| 453 |
+
if p.iouType == "segm":
|
| 454 |
+
g = [g["segmentation"] for g in gt if g["segmentation"] is not None]
|
| 455 |
+
d = [d["segmentation"] for d in dt if d["segmentation"] is not None]
|
| 456 |
+
elif p.iouType == "bbox":
|
| 457 |
+
g = [g["bbox"] for g in gt]
|
| 458 |
+
d = [d["bbox"] for d in dt]
|
| 459 |
+
else:
|
| 460 |
+
raise Exception("unknown iouType for iou computation")
|
| 461 |
+
|
| 462 |
+
# compute iou between each dt and gt region
|
| 463 |
+
iscrowd = [int(o.get("iscrowd", 0)) for o in gt]
|
| 464 |
+
ious = maskUtils.iou(d, g, iscrowd)
|
| 465 |
+
return ious
|
| 466 |
+
|
| 467 |
+
def computeOks(self, imgId, catId):
|
| 468 |
+
p = self.params
|
| 469 |
+
# dimension here should be Nxm
|
| 470 |
+
gts = self._gts[imgId, catId]
|
| 471 |
+
dts = self._dts[imgId, catId]
|
| 472 |
+
inds = np.argsort([-d["score"] for d in dts], kind="mergesort")
|
| 473 |
+
dts = [dts[i] for i in inds]
|
| 474 |
+
if len(dts) > p.maxDets[-1]:
|
| 475 |
+
dts = dts[0 : p.maxDets[-1]]
|
| 476 |
+
# if len(gts) == 0 and len(dts) == 0:
|
| 477 |
+
if len(gts) == 0 or len(dts) == 0:
|
| 478 |
+
return []
|
| 479 |
+
ious = np.zeros((len(dts), len(gts)))
|
| 480 |
+
sigmas = (
|
| 481 |
+
np.array(
|
| 482 |
+
[
|
| 483 |
+
0.26,
|
| 484 |
+
0.25,
|
| 485 |
+
0.25,
|
| 486 |
+
0.35,
|
| 487 |
+
0.35,
|
| 488 |
+
0.79,
|
| 489 |
+
0.79,
|
| 490 |
+
0.72,
|
| 491 |
+
0.72,
|
| 492 |
+
0.62,
|
| 493 |
+
0.62,
|
| 494 |
+
1.07,
|
| 495 |
+
1.07,
|
| 496 |
+
0.87,
|
| 497 |
+
0.87,
|
| 498 |
+
0.89,
|
| 499 |
+
0.89,
|
| 500 |
+
]
|
| 501 |
+
)
|
| 502 |
+
/ 10.0
|
| 503 |
+
)
|
| 504 |
+
vars = (sigmas * 2) ** 2
|
| 505 |
+
k = len(sigmas)
|
| 506 |
+
# compute oks between each detection and ground truth object
|
| 507 |
+
for j, gt in enumerate(gts):
|
| 508 |
+
# create bounds for ignore regions(double the gt bbox)
|
| 509 |
+
g = np.array(gt["keypoints"])
|
| 510 |
+
xg = g[0::3]
|
| 511 |
+
yg = g[1::3]
|
| 512 |
+
vg = g[2::3]
|
| 513 |
+
k1 = np.count_nonzero(vg > 0)
|
| 514 |
+
bb = gt["bbox"]
|
| 515 |
+
x0 = bb[0] - bb[2]
|
| 516 |
+
x1 = bb[0] + bb[2] * 2
|
| 517 |
+
y0 = bb[1] - bb[3]
|
| 518 |
+
y1 = bb[1] + bb[3] * 2
|
| 519 |
+
for i, dt in enumerate(dts):
|
| 520 |
+
d = np.array(dt["keypoints"])
|
| 521 |
+
xd = d[0::3]
|
| 522 |
+
yd = d[1::3]
|
| 523 |
+
if k1 > 0:
|
| 524 |
+
# measure the per-keypoint distance if keypoints visible
|
| 525 |
+
dx = xd - xg
|
| 526 |
+
dy = yd - yg
|
| 527 |
+
else:
|
| 528 |
+
# measure minimum distance to keypoints in (x0,y0) & (x1,y1)
|
| 529 |
+
z = np.zeros(k)
|
| 530 |
+
dx = np.max((z, x0 - xd), axis=0) + np.max((z, xd - x1), axis=0)
|
| 531 |
+
dy = np.max((z, y0 - yd), axis=0) + np.max((z, yd - y1), axis=0)
|
| 532 |
+
e = (dx**2 + dy**2) / vars / (gt["area"] + np.spacing(1)) / 2
|
| 533 |
+
if k1 > 0:
|
| 534 |
+
e = e[vg > 0]
|
| 535 |
+
ious[i, j] = np.sum(np.exp(-e)) / e.shape[0]
|
| 536 |
+
return ious
|
| 537 |
+
|
| 538 |
+
def _extract_mask(self, dt: Dict[str, Any]) -> np.ndarray:
|
| 539 |
+
if "densepose" in dt:
|
| 540 |
+
densepose_results_quantized = dt["densepose"]
|
| 541 |
+
return densepose_results_quantized.labels_uv_uint8[0].numpy()
|
| 542 |
+
elif "cse_mask" in dt:
|
| 543 |
+
return dt["cse_mask"]
|
| 544 |
+
elif "coarse_segm" in dt:
|
| 545 |
+
dy = max(int(dt["bbox"][3]), 1)
|
| 546 |
+
dx = max(int(dt["bbox"][2]), 1)
|
| 547 |
+
return (
|
| 548 |
+
F.interpolate(
|
| 549 |
+
dt["coarse_segm"].unsqueeze(0),
|
| 550 |
+
(dy, dx),
|
| 551 |
+
mode="bilinear",
|
| 552 |
+
align_corners=False,
|
| 553 |
+
)
|
| 554 |
+
.squeeze(0)
|
| 555 |
+
.argmax(0)
|
| 556 |
+
.numpy()
|
| 557 |
+
.astype(np.uint8)
|
| 558 |
+
)
|
| 559 |
+
elif "record_id" in dt:
|
| 560 |
+
assert (
|
| 561 |
+
self.multi_storage is not None
|
| 562 |
+
), f"Storage record id encountered in a detection {dt}, but no storage provided!"
|
| 563 |
+
record = self.multi_storage.get(dt["rank"], dt["record_id"])
|
| 564 |
+
coarse_segm = record["coarse_segm"]
|
| 565 |
+
dy = max(int(dt["bbox"][3]), 1)
|
| 566 |
+
dx = max(int(dt["bbox"][2]), 1)
|
| 567 |
+
return (
|
| 568 |
+
F.interpolate(
|
| 569 |
+
coarse_segm.unsqueeze(0),
|
| 570 |
+
(dy, dx),
|
| 571 |
+
mode="bilinear",
|
| 572 |
+
align_corners=False,
|
| 573 |
+
)
|
| 574 |
+
.squeeze(0)
|
| 575 |
+
.argmax(0)
|
| 576 |
+
.numpy()
|
| 577 |
+
.astype(np.uint8)
|
| 578 |
+
)
|
| 579 |
+
else:
|
| 580 |
+
raise Exception(f"No mask data in the detection: {dt}")
|
| 581 |
+
raise ValueError('The prediction dict needs to contain either "densepose" or "cse_mask"')
|
| 582 |
+
|
| 583 |
+
def _extract_iuv(
|
| 584 |
+
self, densepose_data: np.ndarray, py: np.ndarray, px: np.ndarray, gt: Dict[str, Any]
|
| 585 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 586 |
+
"""
|
| 587 |
+
Extract arrays of I, U and V values at given points as numpy arrays
|
| 588 |
+
given the data mode stored in self._dpDataMode
|
| 589 |
+
"""
|
| 590 |
+
if self._dpDataMode == DensePoseDataMode.IUV_DT:
|
| 591 |
+
# estimated labels and UV (default)
|
| 592 |
+
ipoints = densepose_data[0, py, px]
|
| 593 |
+
upoints = densepose_data[1, py, px] / 255.0 # convert from uint8 by /255.
|
| 594 |
+
vpoints = densepose_data[2, py, px] / 255.0
|
| 595 |
+
elif self._dpDataMode == DensePoseDataMode.IUV_GT:
|
| 596 |
+
# ground truth
|
| 597 |
+
ipoints = np.array(gt["dp_I"])
|
| 598 |
+
upoints = np.array(gt["dp_U"])
|
| 599 |
+
vpoints = np.array(gt["dp_V"])
|
| 600 |
+
elif self._dpDataMode == DensePoseDataMode.I_GT_UV_0:
|
| 601 |
+
# ground truth labels, UV = 0
|
| 602 |
+
ipoints = np.array(gt["dp_I"])
|
| 603 |
+
upoints = upoints * 0.0
|
| 604 |
+
vpoints = vpoints * 0.0
|
| 605 |
+
elif self._dpDataMode == DensePoseDataMode.I_GT_UV_DT:
|
| 606 |
+
# ground truth labels, estimated UV
|
| 607 |
+
ipoints = np.array(gt["dp_I"])
|
| 608 |
+
upoints = densepose_data[1, py, px] / 255.0 # convert from uint8 by /255.
|
| 609 |
+
vpoints = densepose_data[2, py, px] / 255.0
|
| 610 |
+
elif self._dpDataMode == DensePoseDataMode.I_DT_UV_0:
|
| 611 |
+
# estimated labels, UV = 0
|
| 612 |
+
ipoints = densepose_data[0, py, px]
|
| 613 |
+
upoints = upoints * 0.0
|
| 614 |
+
vpoints = vpoints * 0.0
|
| 615 |
+
else:
|
| 616 |
+
raise ValueError(f"Unknown data mode: {self._dpDataMode}")
|
| 617 |
+
return ipoints, upoints, vpoints
|
| 618 |
+
|
| 619 |
+
def computeOgps_single_pair(self, dt, gt, py, px, pt_mask):
|
| 620 |
+
if "densepose" in dt:
|
| 621 |
+
ipoints, upoints, vpoints = self.extract_iuv_from_quantized(dt, gt, py, px, pt_mask)
|
| 622 |
+
return self.computeOgps_single_pair_iuv(dt, gt, ipoints, upoints, vpoints)
|
| 623 |
+
elif "u" in dt:
|
| 624 |
+
ipoints, upoints, vpoints = self.extract_iuv_from_raw(dt, gt, py, px, pt_mask)
|
| 625 |
+
return self.computeOgps_single_pair_iuv(dt, gt, ipoints, upoints, vpoints)
|
| 626 |
+
elif "record_id" in dt:
|
| 627 |
+
assert (
|
| 628 |
+
self.multi_storage is not None
|
| 629 |
+
), f"Storage record id encountered in detection {dt}, but no storage provided!"
|
| 630 |
+
record = self.multi_storage.get(dt["rank"], dt["record_id"])
|
| 631 |
+
record["bbox"] = dt["bbox"]
|
| 632 |
+
if "u" in record:
|
| 633 |
+
ipoints, upoints, vpoints = self.extract_iuv_from_raw(record, gt, py, px, pt_mask)
|
| 634 |
+
return self.computeOgps_single_pair_iuv(dt, gt, ipoints, upoints, vpoints)
|
| 635 |
+
elif "embedding" in record:
|
| 636 |
+
return self.computeOgps_single_pair_cse(
|
| 637 |
+
dt,
|
| 638 |
+
gt,
|
| 639 |
+
py,
|
| 640 |
+
px,
|
| 641 |
+
pt_mask,
|
| 642 |
+
record["coarse_segm"],
|
| 643 |
+
record["embedding"],
|
| 644 |
+
record["bbox"],
|
| 645 |
+
)
|
| 646 |
+
else:
|
| 647 |
+
raise Exception(f"Unknown record format: {record}")
|
| 648 |
+
elif "embedding" in dt:
|
| 649 |
+
return self.computeOgps_single_pair_cse(
|
| 650 |
+
dt, gt, py, px, pt_mask, dt["coarse_segm"], dt["embedding"], dt["bbox"]
|
| 651 |
+
)
|
| 652 |
+
raise Exception(f"Unknown detection format: {dt}")
|
| 653 |
+
|
| 654 |
+
def extract_iuv_from_quantized(self, dt, gt, py, px, pt_mask):
|
| 655 |
+
densepose_results_quantized = dt["densepose"]
|
| 656 |
+
ipoints, upoints, vpoints = self._extract_iuv(
|
| 657 |
+
densepose_results_quantized.labels_uv_uint8.numpy(), py, px, gt
|
| 658 |
+
)
|
| 659 |
+
ipoints[pt_mask == -1] = 0
|
| 660 |
+
return ipoints, upoints, vpoints
|
| 661 |
+
|
| 662 |
+
def extract_iuv_from_raw(self, dt, gt, py, px, pt_mask):
|
| 663 |
+
labels_dt = resample_fine_and_coarse_segm_tensors_to_bbox(
|
| 664 |
+
dt["fine_segm"].unsqueeze(0),
|
| 665 |
+
dt["coarse_segm"].unsqueeze(0),
|
| 666 |
+
dt["bbox"],
|
| 667 |
+
)
|
| 668 |
+
uv = resample_uv_tensors_to_bbox(
|
| 669 |
+
dt["u"].unsqueeze(0), dt["v"].unsqueeze(0), labels_dt.squeeze(0), dt["bbox"]
|
| 670 |
+
)
|
| 671 |
+
labels_uv_uint8 = torch.cat((labels_dt.byte(), (uv * 255).clamp(0, 255).byte()))
|
| 672 |
+
ipoints, upoints, vpoints = self._extract_iuv(labels_uv_uint8.numpy(), py, px, gt)
|
| 673 |
+
ipoints[pt_mask == -1] = 0
|
| 674 |
+
return ipoints, upoints, vpoints
|
| 675 |
+
|
| 676 |
+
def computeOgps_single_pair_iuv(self, dt, gt, ipoints, upoints, vpoints):
|
| 677 |
+
cVertsGT, ClosestVertsGTTransformed = self.findAllClosestVertsGT(gt)
|
| 678 |
+
cVerts = self.findAllClosestVertsUV(upoints, vpoints, ipoints)
|
| 679 |
+
# Get pairwise geodesic distances between gt and estimated mesh points.
|
| 680 |
+
dist = self.getDistancesUV(ClosestVertsGTTransformed, cVerts)
|
| 681 |
+
# Compute the Ogps measure.
|
| 682 |
+
# Find the mean geodesic normalization distance for
|
| 683 |
+
# each GT point, based on which part it is on.
|
| 684 |
+
Current_Mean_Distances = self.Mean_Distances[
|
| 685 |
+
self.CoarseParts[self.Part_ids[cVertsGT[cVertsGT > 0].astype(int) - 1]]
|
| 686 |
+
]
|
| 687 |
+
return dist, Current_Mean_Distances
|
| 688 |
+
|
| 689 |
+
def computeOgps_single_pair_cse(
|
| 690 |
+
self, dt, gt, py, px, pt_mask, coarse_segm, embedding, bbox_xywh_abs
|
| 691 |
+
):
|
| 692 |
+
# 0-based mesh vertex indices
|
| 693 |
+
cVertsGT = torch.as_tensor(gt["dp_vertex"], dtype=torch.int64)
|
| 694 |
+
# label for each pixel of the bbox, [H, W] tensor of long
|
| 695 |
+
labels_dt = resample_coarse_segm_tensor_to_bbox(
|
| 696 |
+
coarse_segm.unsqueeze(0), bbox_xywh_abs
|
| 697 |
+
).squeeze(0)
|
| 698 |
+
x, y, w, h = bbox_xywh_abs
|
| 699 |
+
# embedding for each pixel of the bbox, [D, H, W] tensor of float32
|
| 700 |
+
embedding = F.interpolate(
|
| 701 |
+
embedding.unsqueeze(0), (int(h), int(w)), mode="bilinear", align_corners=False
|
| 702 |
+
).squeeze(0)
|
| 703 |
+
# valid locations py, px
|
| 704 |
+
py_pt = torch.from_numpy(py[pt_mask > -1])
|
| 705 |
+
px_pt = torch.from_numpy(px[pt_mask > -1])
|
| 706 |
+
cVerts = torch.ones_like(cVertsGT) * -1
|
| 707 |
+
cVerts[pt_mask > -1] = self.findClosestVertsCse(
|
| 708 |
+
embedding, py_pt, px_pt, labels_dt, gt["ref_model"]
|
| 709 |
+
)
|
| 710 |
+
# Get pairwise geodesic distances between gt and estimated mesh points.
|
| 711 |
+
dist = self.getDistancesCse(cVertsGT, cVerts, gt["ref_model"])
|
| 712 |
+
# normalize distances
|
| 713 |
+
if (gt["ref_model"] == "smpl_27554") and ("dp_I" in gt):
|
| 714 |
+
Current_Mean_Distances = self.Mean_Distances[
|
| 715 |
+
self.CoarseParts[np.array(gt["dp_I"], dtype=int)]
|
| 716 |
+
]
|
| 717 |
+
else:
|
| 718 |
+
Current_Mean_Distances = 0.255
|
| 719 |
+
return dist, Current_Mean_Distances
|
| 720 |
+
|
| 721 |
+
def computeOgps(self, imgId, catId):
|
| 722 |
+
p = self.params
|
| 723 |
+
# dimension here should be Nxm
|
| 724 |
+
g = self._gts[imgId, catId]
|
| 725 |
+
d = self._dts[imgId, catId]
|
| 726 |
+
inds = np.argsort([-d_["score"] for d_ in d], kind="mergesort")
|
| 727 |
+
d = [d[i] for i in inds]
|
| 728 |
+
if len(d) > p.maxDets[-1]:
|
| 729 |
+
d = d[0 : p.maxDets[-1]]
|
| 730 |
+
# if len(gts) == 0 and len(dts) == 0:
|
| 731 |
+
if len(g) == 0 or len(d) == 0:
|
| 732 |
+
return []
|
| 733 |
+
ious = np.zeros((len(d), len(g)))
|
| 734 |
+
# compute opgs between each detection and ground truth object
|
| 735 |
+
# sigma = self.sigma #0.255 # dist = 0.3m corresponds to ogps = 0.5
|
| 736 |
+
# 1 # dist = 0.3m corresponds to ogps = 0.96
|
| 737 |
+
# 1.45 # dist = 1.7m (person height) corresponds to ogps = 0.5)
|
| 738 |
+
for j, gt in enumerate(g):
|
| 739 |
+
if not gt["ignore"]:
|
| 740 |
+
g_ = gt["bbox"]
|
| 741 |
+
for i, dt in enumerate(d):
|
| 742 |
+
#
|
| 743 |
+
dy = int(dt["bbox"][3])
|
| 744 |
+
dx = int(dt["bbox"][2])
|
| 745 |
+
dp_x = np.array(gt["dp_x"]) * g_[2] / 255.0
|
| 746 |
+
dp_y = np.array(gt["dp_y"]) * g_[3] / 255.0
|
| 747 |
+
py = (dp_y + g_[1] - dt["bbox"][1]).astype(int)
|
| 748 |
+
px = (dp_x + g_[0] - dt["bbox"][0]).astype(int)
|
| 749 |
+
#
|
| 750 |
+
pts = np.zeros(len(px))
|
| 751 |
+
pts[px >= dx] = -1
|
| 752 |
+
pts[py >= dy] = -1
|
| 753 |
+
pts[px < 0] = -1
|
| 754 |
+
pts[py < 0] = -1
|
| 755 |
+
if len(pts) < 1:
|
| 756 |
+
ogps = 0.0
|
| 757 |
+
elif np.max(pts) == -1:
|
| 758 |
+
ogps = 0.0
|
| 759 |
+
else:
|
| 760 |
+
px[pts == -1] = 0
|
| 761 |
+
py[pts == -1] = 0
|
| 762 |
+
dists_between_matches, dist_norm_coeffs = self.computeOgps_single_pair(
|
| 763 |
+
dt, gt, py, px, pts
|
| 764 |
+
)
|
| 765 |
+
# Compute gps
|
| 766 |
+
ogps_values = np.exp(
|
| 767 |
+
-(dists_between_matches**2) / (2 * (dist_norm_coeffs**2))
|
| 768 |
+
)
|
| 769 |
+
#
|
| 770 |
+
ogps = np.mean(ogps_values) if len(ogps_values) > 0 else 0.0
|
| 771 |
+
ious[i, j] = ogps
|
| 772 |
+
|
| 773 |
+
gbb = [gt["bbox"] for gt in g]
|
| 774 |
+
dbb = [dt["bbox"] for dt in d]
|
| 775 |
+
|
| 776 |
+
# compute iou between each dt and gt region
|
| 777 |
+
iscrowd = [int(o.get("iscrowd", 0)) for o in g]
|
| 778 |
+
ious_bb = maskUtils.iou(dbb, gbb, iscrowd)
|
| 779 |
+
return ious, ious_bb
|
| 780 |
+
|
| 781 |
+
def evaluateImg(self, imgId, catId, aRng, maxDet):
|
| 782 |
+
"""
|
| 783 |
+
perform evaluation for single category and image
|
| 784 |
+
:return: dict (single image results)
|
| 785 |
+
"""
|
| 786 |
+
|
| 787 |
+
p = self.params
|
| 788 |
+
if p.useCats:
|
| 789 |
+
gt = self._gts[imgId, catId]
|
| 790 |
+
dt = self._dts[imgId, catId]
|
| 791 |
+
else:
|
| 792 |
+
gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
|
| 793 |
+
dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
|
| 794 |
+
if len(gt) == 0 and len(dt) == 0:
|
| 795 |
+
return None
|
| 796 |
+
|
| 797 |
+
for g in gt:
|
| 798 |
+
# g['_ignore'] = g['ignore']
|
| 799 |
+
if g["ignore"] or (g["area"] < aRng[0] or g["area"] > aRng[1]):
|
| 800 |
+
g["_ignore"] = True
|
| 801 |
+
else:
|
| 802 |
+
g["_ignore"] = False
|
| 803 |
+
|
| 804 |
+
# sort dt highest score first, sort gt ignore last
|
| 805 |
+
gtind = np.argsort([g["_ignore"] for g in gt], kind="mergesort")
|
| 806 |
+
gt = [gt[i] for i in gtind]
|
| 807 |
+
dtind = np.argsort([-d["score"] for d in dt], kind="mergesort")
|
| 808 |
+
dt = [dt[i] for i in dtind[0:maxDet]]
|
| 809 |
+
iscrowd = [int(o.get("iscrowd", 0)) for o in gt]
|
| 810 |
+
# load computed ious
|
| 811 |
+
if p.iouType == "densepose":
|
| 812 |
+
# print('Checking the length', len(self.ious[imgId, catId]))
|
| 813 |
+
# if len(self.ious[imgId, catId]) == 0:
|
| 814 |
+
# print(self.ious[imgId, catId])
|
| 815 |
+
ious = (
|
| 816 |
+
self.ious[imgId, catId][0][:, gtind]
|
| 817 |
+
if len(self.ious[imgId, catId]) > 0
|
| 818 |
+
else self.ious[imgId, catId]
|
| 819 |
+
)
|
| 820 |
+
ioubs = (
|
| 821 |
+
self.ious[imgId, catId][1][:, gtind]
|
| 822 |
+
if len(self.ious[imgId, catId]) > 0
|
| 823 |
+
else self.ious[imgId, catId]
|
| 824 |
+
)
|
| 825 |
+
if self._dpEvalMode in {DensePoseEvalMode.GPSM, DensePoseEvalMode.IOU}:
|
| 826 |
+
iousM = (
|
| 827 |
+
self.real_ious[imgId, catId][:, gtind]
|
| 828 |
+
if len(self.real_ious[imgId, catId]) > 0
|
| 829 |
+
else self.real_ious[imgId, catId]
|
| 830 |
+
)
|
| 831 |
+
else:
|
| 832 |
+
ious = (
|
| 833 |
+
self.ious[imgId, catId][:, gtind]
|
| 834 |
+
if len(self.ious[imgId, catId]) > 0
|
| 835 |
+
else self.ious[imgId, catId]
|
| 836 |
+
)
|
| 837 |
+
|
| 838 |
+
T = len(p.iouThrs)
|
| 839 |
+
G = len(gt)
|
| 840 |
+
D = len(dt)
|
| 841 |
+
gtm = np.zeros((T, G))
|
| 842 |
+
dtm = np.zeros((T, D))
|
| 843 |
+
gtIg = np.array([g["_ignore"] for g in gt])
|
| 844 |
+
dtIg = np.zeros((T, D))
|
| 845 |
+
if np.all(gtIg) and p.iouType == "densepose":
|
| 846 |
+
dtIg = np.logical_or(dtIg, True)
|
| 847 |
+
|
| 848 |
+
if len(ious) > 0: # and not p.iouType == 'densepose':
|
| 849 |
+
for tind, t in enumerate(p.iouThrs):
|
| 850 |
+
for dind, d in enumerate(dt):
|
| 851 |
+
# information about best match so far (m=-1 -> unmatched)
|
| 852 |
+
iou = min([t, 1 - 1e-10])
|
| 853 |
+
m = -1
|
| 854 |
+
for gind, _g in enumerate(gt):
|
| 855 |
+
# if this gt already matched, and not a crowd, continue
|
| 856 |
+
if gtm[tind, gind] > 0 and not iscrowd[gind]:
|
| 857 |
+
continue
|
| 858 |
+
# if dt matched to reg gt, and on ignore gt, stop
|
| 859 |
+
if m > -1 and gtIg[m] == 0 and gtIg[gind] == 1:
|
| 860 |
+
break
|
| 861 |
+
if p.iouType == "densepose":
|
| 862 |
+
if self._dpEvalMode == DensePoseEvalMode.GPSM:
|
| 863 |
+
new_iou = np.sqrt(iousM[dind, gind] * ious[dind, gind])
|
| 864 |
+
elif self._dpEvalMode == DensePoseEvalMode.IOU:
|
| 865 |
+
new_iou = iousM[dind, gind]
|
| 866 |
+
elif self._dpEvalMode == DensePoseEvalMode.GPS:
|
| 867 |
+
new_iou = ious[dind, gind]
|
| 868 |
+
else:
|
| 869 |
+
new_iou = ious[dind, gind]
|
| 870 |
+
if new_iou < iou:
|
| 871 |
+
continue
|
| 872 |
+
if new_iou == 0.0:
|
| 873 |
+
continue
|
| 874 |
+
# if match successful and best so far, store appropriately
|
| 875 |
+
iou = new_iou
|
| 876 |
+
m = gind
|
| 877 |
+
# if match made store id of match for both dt and gt
|
| 878 |
+
if m == -1:
|
| 879 |
+
continue
|
| 880 |
+
dtIg[tind, dind] = gtIg[m]
|
| 881 |
+
dtm[tind, dind] = gt[m]["id"]
|
| 882 |
+
gtm[tind, m] = d["id"]
|
| 883 |
+
|
| 884 |
+
if p.iouType == "densepose":
|
| 885 |
+
if not len(ioubs) == 0:
|
| 886 |
+
for dind, d in enumerate(dt):
|
| 887 |
+
# information about best match so far (m=-1 -> unmatched)
|
| 888 |
+
if dtm[tind, dind] == 0:
|
| 889 |
+
ioub = 0.8
|
| 890 |
+
m = -1
|
| 891 |
+
for gind, _g in enumerate(gt):
|
| 892 |
+
# if this gt already matched, and not a crowd, continue
|
| 893 |
+
if gtm[tind, gind] > 0 and not iscrowd[gind]:
|
| 894 |
+
continue
|
| 895 |
+
# continue to next gt unless better match made
|
| 896 |
+
if ioubs[dind, gind] < ioub:
|
| 897 |
+
continue
|
| 898 |
+
# if match successful and best so far, store appropriately
|
| 899 |
+
ioub = ioubs[dind, gind]
|
| 900 |
+
m = gind
|
| 901 |
+
# if match made store id of match for both dt and gt
|
| 902 |
+
if m > -1:
|
| 903 |
+
dtIg[:, dind] = gtIg[m]
|
| 904 |
+
if gtIg[m]:
|
| 905 |
+
dtm[tind, dind] = gt[m]["id"]
|
| 906 |
+
gtm[tind, m] = d["id"]
|
| 907 |
+
# set unmatched detections outside of area range to ignore
|
| 908 |
+
a = np.array([d["area"] < aRng[0] or d["area"] > aRng[1] for d in dt]).reshape((1, len(dt)))
|
| 909 |
+
dtIg = np.logical_or(dtIg, np.logical_and(dtm == 0, np.repeat(a, T, 0)))
|
| 910 |
+
# store results for given image and category
|
| 911 |
+
# print('Done with the function', len(self.ious[imgId, catId]))
|
| 912 |
+
return {
|
| 913 |
+
"image_id": imgId,
|
| 914 |
+
"category_id": catId,
|
| 915 |
+
"aRng": aRng,
|
| 916 |
+
"maxDet": maxDet,
|
| 917 |
+
"dtIds": [d["id"] for d in dt],
|
| 918 |
+
"gtIds": [g["id"] for g in gt],
|
| 919 |
+
"dtMatches": dtm,
|
| 920 |
+
"gtMatches": gtm,
|
| 921 |
+
"dtScores": [d["score"] for d in dt],
|
| 922 |
+
"gtIgnore": gtIg,
|
| 923 |
+
"dtIgnore": dtIg,
|
| 924 |
+
}
|
| 925 |
+
|
| 926 |
+
def accumulate(self, p=None):
|
| 927 |
+
"""
|
| 928 |
+
Accumulate per image evaluation results and store the result in self.eval
|
| 929 |
+
:param p: input params for evaluation
|
| 930 |
+
:return: None
|
| 931 |
+
"""
|
| 932 |
+
logger.info("Accumulating evaluation results...")
|
| 933 |
+
tic = time.time()
|
| 934 |
+
if not self.evalImgs:
|
| 935 |
+
logger.info("Please run evaluate() first")
|
| 936 |
+
# allows input customized parameters
|
| 937 |
+
if p is None:
|
| 938 |
+
p = self.params
|
| 939 |
+
p.catIds = p.catIds if p.useCats == 1 else [-1]
|
| 940 |
+
T = len(p.iouThrs)
|
| 941 |
+
R = len(p.recThrs)
|
| 942 |
+
K = len(p.catIds) if p.useCats else 1
|
| 943 |
+
A = len(p.areaRng)
|
| 944 |
+
M = len(p.maxDets)
|
| 945 |
+
precision = -(np.ones((T, R, K, A, M))) # -1 for the precision of absent categories
|
| 946 |
+
recall = -(np.ones((T, K, A, M)))
|
| 947 |
+
|
| 948 |
+
# create dictionary for future indexing
|
| 949 |
+
logger.info("Categories: {}".format(p.catIds))
|
| 950 |
+
_pe = self._paramsEval
|
| 951 |
+
catIds = _pe.catIds if _pe.useCats else [-1]
|
| 952 |
+
setK = set(catIds)
|
| 953 |
+
setA = set(map(tuple, _pe.areaRng))
|
| 954 |
+
setM = set(_pe.maxDets)
|
| 955 |
+
setI = set(_pe.imgIds)
|
| 956 |
+
# get inds to evaluate
|
| 957 |
+
k_list = [n for n, k in enumerate(p.catIds) if k in setK]
|
| 958 |
+
m_list = [m for n, m in enumerate(p.maxDets) if m in setM]
|
| 959 |
+
a_list = [n for n, a in enumerate(map(lambda x: tuple(x), p.areaRng)) if a in setA]
|
| 960 |
+
i_list = [n for n, i in enumerate(p.imgIds) if i in setI]
|
| 961 |
+
I0 = len(_pe.imgIds)
|
| 962 |
+
A0 = len(_pe.areaRng)
|
| 963 |
+
# retrieve E at each category, area range, and max number of detections
|
| 964 |
+
for k, k0 in enumerate(k_list):
|
| 965 |
+
Nk = k0 * A0 * I0
|
| 966 |
+
for a, a0 in enumerate(a_list):
|
| 967 |
+
Na = a0 * I0
|
| 968 |
+
for m, maxDet in enumerate(m_list):
|
| 969 |
+
E = [self.evalImgs[Nk + Na + i] for i in i_list]
|
| 970 |
+
E = [e for e in E if e is not None]
|
| 971 |
+
if len(E) == 0:
|
| 972 |
+
continue
|
| 973 |
+
dtScores = np.concatenate([e["dtScores"][0:maxDet] for e in E])
|
| 974 |
+
|
| 975 |
+
# different sorting method generates slightly different results.
|
| 976 |
+
# mergesort is used to be consistent as Matlab implementation.
|
| 977 |
+
inds = np.argsort(-dtScores, kind="mergesort")
|
| 978 |
+
|
| 979 |
+
dtm = np.concatenate([e["dtMatches"][:, 0:maxDet] for e in E], axis=1)[:, inds]
|
| 980 |
+
dtIg = np.concatenate([e["dtIgnore"][:, 0:maxDet] for e in E], axis=1)[:, inds]
|
| 981 |
+
gtIg = np.concatenate([e["gtIgnore"] for e in E])
|
| 982 |
+
npig = np.count_nonzero(gtIg == 0)
|
| 983 |
+
if npig == 0:
|
| 984 |
+
continue
|
| 985 |
+
tps = np.logical_and(dtm, np.logical_not(dtIg))
|
| 986 |
+
fps = np.logical_and(np.logical_not(dtm), np.logical_not(dtIg))
|
| 987 |
+
tp_sum = np.cumsum(tps, axis=1).astype(dtype=float)
|
| 988 |
+
fp_sum = np.cumsum(fps, axis=1).astype(dtype=float)
|
| 989 |
+
for t, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):
|
| 990 |
+
tp = np.array(tp)
|
| 991 |
+
fp = np.array(fp)
|
| 992 |
+
nd = len(tp)
|
| 993 |
+
rc = tp / npig
|
| 994 |
+
pr = tp / (fp + tp + np.spacing(1))
|
| 995 |
+
q = np.zeros((R,))
|
| 996 |
+
|
| 997 |
+
if nd:
|
| 998 |
+
recall[t, k, a, m] = rc[-1]
|
| 999 |
+
else:
|
| 1000 |
+
recall[t, k, a, m] = 0
|
| 1001 |
+
|
| 1002 |
+
# numpy is slow without cython optimization for accessing elements
|
| 1003 |
+
# use python array gets significant speed improvement
|
| 1004 |
+
pr = pr.tolist()
|
| 1005 |
+
q = q.tolist()
|
| 1006 |
+
|
| 1007 |
+
for i in range(nd - 1, 0, -1):
|
| 1008 |
+
if pr[i] > pr[i - 1]:
|
| 1009 |
+
pr[i - 1] = pr[i]
|
| 1010 |
+
|
| 1011 |
+
inds = np.searchsorted(rc, p.recThrs, side="left")
|
| 1012 |
+
try:
|
| 1013 |
+
for ri, pi in enumerate(inds):
|
| 1014 |
+
q[ri] = pr[pi]
|
| 1015 |
+
except Exception:
|
| 1016 |
+
pass
|
| 1017 |
+
precision[t, :, k, a, m] = np.array(q)
|
| 1018 |
+
logger.info(
|
| 1019 |
+
"Final: max precision {}, min precision {}".format(np.max(precision), np.min(precision))
|
| 1020 |
+
)
|
| 1021 |
+
self.eval = {
|
| 1022 |
+
"params": p,
|
| 1023 |
+
"counts": [T, R, K, A, M],
|
| 1024 |
+
"date": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
| 1025 |
+
"precision": precision,
|
| 1026 |
+
"recall": recall,
|
| 1027 |
+
}
|
| 1028 |
+
toc = time.time()
|
| 1029 |
+
logger.info("DONE (t={:0.2f}s).".format(toc - tic))
|
| 1030 |
+
|
| 1031 |
+
def summarize(self):
|
| 1032 |
+
"""
|
| 1033 |
+
Compute and display summary metrics for evaluation results.
|
| 1034 |
+
Note this function can *only* be applied on the default parameter setting
|
| 1035 |
+
"""
|
| 1036 |
+
|
| 1037 |
+
def _summarize(ap=1, iouThr=None, areaRng="all", maxDets=100):
|
| 1038 |
+
p = self.params
|
| 1039 |
+
iStr = " {:<18} {} @[ {}={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}"
|
| 1040 |
+
titleStr = "Average Precision" if ap == 1 else "Average Recall"
|
| 1041 |
+
typeStr = "(AP)" if ap == 1 else "(AR)"
|
| 1042 |
+
measure = "IoU"
|
| 1043 |
+
if self.params.iouType == "keypoints":
|
| 1044 |
+
measure = "OKS"
|
| 1045 |
+
elif self.params.iouType == "densepose":
|
| 1046 |
+
measure = "OGPS"
|
| 1047 |
+
iouStr = (
|
| 1048 |
+
"{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1])
|
| 1049 |
+
if iouThr is None
|
| 1050 |
+
else "{:0.2f}".format(iouThr)
|
| 1051 |
+
)
|
| 1052 |
+
|
| 1053 |
+
aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
|
| 1054 |
+
mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
|
| 1055 |
+
if ap == 1:
|
| 1056 |
+
# dimension of precision: [TxRxKxAxM]
|
| 1057 |
+
s = self.eval["precision"]
|
| 1058 |
+
# IoU
|
| 1059 |
+
if iouThr is not None:
|
| 1060 |
+
t = np.where(np.abs(iouThr - p.iouThrs) < 0.001)[0]
|
| 1061 |
+
s = s[t]
|
| 1062 |
+
s = s[:, :, :, aind, mind]
|
| 1063 |
+
else:
|
| 1064 |
+
# dimension of recall: [TxKxAxM]
|
| 1065 |
+
s = self.eval["recall"]
|
| 1066 |
+
if iouThr is not None:
|
| 1067 |
+
t = np.where(np.abs(iouThr - p.iouThrs) < 0.001)[0]
|
| 1068 |
+
s = s[t]
|
| 1069 |
+
s = s[:, :, aind, mind]
|
| 1070 |
+
if len(s[s > -1]) == 0:
|
| 1071 |
+
mean_s = -1
|
| 1072 |
+
else:
|
| 1073 |
+
mean_s = np.mean(s[s > -1])
|
| 1074 |
+
logger.info(iStr.format(titleStr, typeStr, measure, iouStr, areaRng, maxDets, mean_s))
|
| 1075 |
+
return mean_s
|
| 1076 |
+
|
| 1077 |
+
def _summarizeDets():
|
| 1078 |
+
stats = np.zeros((12,))
|
| 1079 |
+
stats[0] = _summarize(1)
|
| 1080 |
+
stats[1] = _summarize(1, iouThr=0.5, maxDets=self.params.maxDets[2])
|
| 1081 |
+
stats[2] = _summarize(1, iouThr=0.75, maxDets=self.params.maxDets[2])
|
| 1082 |
+
stats[3] = _summarize(1, areaRng="small", maxDets=self.params.maxDets[2])
|
| 1083 |
+
stats[4] = _summarize(1, areaRng="medium", maxDets=self.params.maxDets[2])
|
| 1084 |
+
stats[5] = _summarize(1, areaRng="large", maxDets=self.params.maxDets[2])
|
| 1085 |
+
stats[6] = _summarize(0, maxDets=self.params.maxDets[0])
|
| 1086 |
+
stats[7] = _summarize(0, maxDets=self.params.maxDets[1])
|
| 1087 |
+
stats[8] = _summarize(0, maxDets=self.params.maxDets[2])
|
| 1088 |
+
stats[9] = _summarize(0, areaRng="small", maxDets=self.params.maxDets[2])
|
| 1089 |
+
stats[10] = _summarize(0, areaRng="medium", maxDets=self.params.maxDets[2])
|
| 1090 |
+
stats[11] = _summarize(0, areaRng="large", maxDets=self.params.maxDets[2])
|
| 1091 |
+
return stats
|
| 1092 |
+
|
| 1093 |
+
def _summarizeKps():
|
| 1094 |
+
stats = np.zeros((10,))
|
| 1095 |
+
stats[0] = _summarize(1, maxDets=20)
|
| 1096 |
+
stats[1] = _summarize(1, maxDets=20, iouThr=0.5)
|
| 1097 |
+
stats[2] = _summarize(1, maxDets=20, iouThr=0.75)
|
| 1098 |
+
stats[3] = _summarize(1, maxDets=20, areaRng="medium")
|
| 1099 |
+
stats[4] = _summarize(1, maxDets=20, areaRng="large")
|
| 1100 |
+
stats[5] = _summarize(0, maxDets=20)
|
| 1101 |
+
stats[6] = _summarize(0, maxDets=20, iouThr=0.5)
|
| 1102 |
+
stats[7] = _summarize(0, maxDets=20, iouThr=0.75)
|
| 1103 |
+
stats[8] = _summarize(0, maxDets=20, areaRng="medium")
|
| 1104 |
+
stats[9] = _summarize(0, maxDets=20, areaRng="large")
|
| 1105 |
+
return stats
|
| 1106 |
+
|
| 1107 |
+
def _summarizeUvs():
|
| 1108 |
+
stats = [_summarize(1, maxDets=self.params.maxDets[0])]
|
| 1109 |
+
min_threshold = self.params.iouThrs.min()
|
| 1110 |
+
if min_threshold <= 0.201:
|
| 1111 |
+
stats += [_summarize(1, maxDets=self.params.maxDets[0], iouThr=0.2)]
|
| 1112 |
+
if min_threshold <= 0.301:
|
| 1113 |
+
stats += [_summarize(1, maxDets=self.params.maxDets[0], iouThr=0.3)]
|
| 1114 |
+
if min_threshold <= 0.401:
|
| 1115 |
+
stats += [_summarize(1, maxDets=self.params.maxDets[0], iouThr=0.4)]
|
| 1116 |
+
stats += [
|
| 1117 |
+
_summarize(1, maxDets=self.params.maxDets[0], iouThr=0.5),
|
| 1118 |
+
_summarize(1, maxDets=self.params.maxDets[0], iouThr=0.75),
|
| 1119 |
+
_summarize(1, maxDets=self.params.maxDets[0], areaRng="medium"),
|
| 1120 |
+
_summarize(1, maxDets=self.params.maxDets[0], areaRng="large"),
|
| 1121 |
+
_summarize(0, maxDets=self.params.maxDets[0]),
|
| 1122 |
+
_summarize(0, maxDets=self.params.maxDets[0], iouThr=0.5),
|
| 1123 |
+
_summarize(0, maxDets=self.params.maxDets[0], iouThr=0.75),
|
| 1124 |
+
_summarize(0, maxDets=self.params.maxDets[0], areaRng="medium"),
|
| 1125 |
+
_summarize(0, maxDets=self.params.maxDets[0], areaRng="large"),
|
| 1126 |
+
]
|
| 1127 |
+
return np.array(stats)
|
| 1128 |
+
|
| 1129 |
+
def _summarizeUvsOld():
|
| 1130 |
+
stats = np.zeros((18,))
|
| 1131 |
+
stats[0] = _summarize(1, maxDets=self.params.maxDets[0])
|
| 1132 |
+
stats[1] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.5)
|
| 1133 |
+
stats[2] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.55)
|
| 1134 |
+
stats[3] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.60)
|
| 1135 |
+
stats[4] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.65)
|
| 1136 |
+
stats[5] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.70)
|
| 1137 |
+
stats[6] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.75)
|
| 1138 |
+
stats[7] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.80)
|
| 1139 |
+
stats[8] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.85)
|
| 1140 |
+
stats[9] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.90)
|
| 1141 |
+
stats[10] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.95)
|
| 1142 |
+
stats[11] = _summarize(1, maxDets=self.params.maxDets[0], areaRng="medium")
|
| 1143 |
+
stats[12] = _summarize(1, maxDets=self.params.maxDets[0], areaRng="large")
|
| 1144 |
+
stats[13] = _summarize(0, maxDets=self.params.maxDets[0])
|
| 1145 |
+
stats[14] = _summarize(0, maxDets=self.params.maxDets[0], iouThr=0.5)
|
| 1146 |
+
stats[15] = _summarize(0, maxDets=self.params.maxDets[0], iouThr=0.75)
|
| 1147 |
+
stats[16] = _summarize(0, maxDets=self.params.maxDets[0], areaRng="medium")
|
| 1148 |
+
stats[17] = _summarize(0, maxDets=self.params.maxDets[0], areaRng="large")
|
| 1149 |
+
return stats
|
| 1150 |
+
|
| 1151 |
+
if not self.eval:
|
| 1152 |
+
raise Exception("Please run accumulate() first")
|
| 1153 |
+
iouType = self.params.iouType
|
| 1154 |
+
if iouType in ["segm", "bbox"]:
|
| 1155 |
+
summarize = _summarizeDets
|
| 1156 |
+
elif iouType in ["keypoints"]:
|
| 1157 |
+
summarize = _summarizeKps
|
| 1158 |
+
elif iouType in ["densepose"]:
|
| 1159 |
+
summarize = _summarizeUvs
|
| 1160 |
+
self.stats = summarize()
|
| 1161 |
+
|
| 1162 |
+
def __str__(self):
|
| 1163 |
+
self.summarize()
|
| 1164 |
+
|
| 1165 |
+
# ================ functions for dense pose ==============================
|
| 1166 |
+
def findAllClosestVertsUV(self, U_points, V_points, Index_points):
|
| 1167 |
+
ClosestVerts = np.ones(Index_points.shape) * -1
|
| 1168 |
+
for i in np.arange(24):
|
| 1169 |
+
#
|
| 1170 |
+
if (i + 1) in Index_points:
|
| 1171 |
+
UVs = np.array(
|
| 1172 |
+
[U_points[Index_points == (i + 1)], V_points[Index_points == (i + 1)]]
|
| 1173 |
+
)
|
| 1174 |
+
Current_Part_UVs = self.Part_UVs[i]
|
| 1175 |
+
Current_Part_ClosestVertInds = self.Part_ClosestVertInds[i]
|
| 1176 |
+
D = ssd.cdist(Current_Part_UVs.transpose(), UVs.transpose()).squeeze()
|
| 1177 |
+
ClosestVerts[Index_points == (i + 1)] = Current_Part_ClosestVertInds[
|
| 1178 |
+
np.argmin(D, axis=0)
|
| 1179 |
+
]
|
| 1180 |
+
ClosestVertsTransformed = self.PDIST_transform[ClosestVerts.astype(int) - 1]
|
| 1181 |
+
ClosestVertsTransformed[ClosestVerts < 0] = 0
|
| 1182 |
+
return ClosestVertsTransformed
|
| 1183 |
+
|
| 1184 |
+
def findClosestVertsCse(self, embedding, py, px, mask, mesh_name):
|
| 1185 |
+
mesh_vertex_embeddings = self.embedder(mesh_name)
|
| 1186 |
+
pixel_embeddings = embedding[:, py, px].t().to(device="cuda")
|
| 1187 |
+
mask_vals = mask[py, px]
|
| 1188 |
+
edm = squared_euclidean_distance_matrix(pixel_embeddings, mesh_vertex_embeddings)
|
| 1189 |
+
vertex_indices = edm.argmin(dim=1).cpu()
|
| 1190 |
+
vertex_indices[mask_vals <= 0] = -1
|
| 1191 |
+
return vertex_indices
|
| 1192 |
+
|
| 1193 |
+
def findAllClosestVertsGT(self, gt):
|
| 1194 |
+
#
|
| 1195 |
+
I_gt = np.array(gt["dp_I"])
|
| 1196 |
+
U_gt = np.array(gt["dp_U"])
|
| 1197 |
+
V_gt = np.array(gt["dp_V"])
|
| 1198 |
+
#
|
| 1199 |
+
# print(I_gt)
|
| 1200 |
+
#
|
| 1201 |
+
ClosestVertsGT = np.ones(I_gt.shape) * -1
|
| 1202 |
+
for i in np.arange(24):
|
| 1203 |
+
if (i + 1) in I_gt:
|
| 1204 |
+
UVs = np.array([U_gt[I_gt == (i + 1)], V_gt[I_gt == (i + 1)]])
|
| 1205 |
+
Current_Part_UVs = self.Part_UVs[i]
|
| 1206 |
+
Current_Part_ClosestVertInds = self.Part_ClosestVertInds[i]
|
| 1207 |
+
D = ssd.cdist(Current_Part_UVs.transpose(), UVs.transpose()).squeeze()
|
| 1208 |
+
ClosestVertsGT[I_gt == (i + 1)] = Current_Part_ClosestVertInds[np.argmin(D, axis=0)]
|
| 1209 |
+
#
|
| 1210 |
+
ClosestVertsGTTransformed = self.PDIST_transform[ClosestVertsGT.astype(int) - 1]
|
| 1211 |
+
ClosestVertsGTTransformed[ClosestVertsGT < 0] = 0
|
| 1212 |
+
return ClosestVertsGT, ClosestVertsGTTransformed
|
| 1213 |
+
|
| 1214 |
+
def getDistancesCse(self, cVertsGT, cVerts, mesh_name):
|
| 1215 |
+
geodists_vertices = torch.ones_like(cVertsGT) * float("inf")
|
| 1216 |
+
selected = (cVertsGT >= 0) * (cVerts >= 0)
|
| 1217 |
+
mesh = create_mesh(mesh_name, "cpu")
|
| 1218 |
+
geodists_vertices[selected] = mesh.geodists[cVertsGT[selected], cVerts[selected]]
|
| 1219 |
+
return geodists_vertices.numpy()
|
| 1220 |
+
|
| 1221 |
+
def getDistancesUV(self, cVertsGT, cVerts):
|
| 1222 |
+
#
|
| 1223 |
+
n = 27554
|
| 1224 |
+
dists = []
|
| 1225 |
+
for d in range(len(cVertsGT)):
|
| 1226 |
+
if cVertsGT[d] > 0:
|
| 1227 |
+
if cVerts[d] > 0:
|
| 1228 |
+
i = cVertsGT[d] - 1
|
| 1229 |
+
j = cVerts[d] - 1
|
| 1230 |
+
if j == i:
|
| 1231 |
+
dists.append(0)
|
| 1232 |
+
elif j > i:
|
| 1233 |
+
ccc = i
|
| 1234 |
+
i = j
|
| 1235 |
+
j = ccc
|
| 1236 |
+
i = n - i - 1
|
| 1237 |
+
j = n - j - 1
|
| 1238 |
+
k = (n * (n - 1) / 2) - (n - i) * ((n - i) - 1) / 2 + j - i - 1
|
| 1239 |
+
k = (n * n - n) / 2 - k - 1
|
| 1240 |
+
dists.append(self.Pdist_matrix[int(k)][0])
|
| 1241 |
+
else:
|
| 1242 |
+
i = n - i - 1
|
| 1243 |
+
j = n - j - 1
|
| 1244 |
+
k = (n * (n - 1) / 2) - (n - i) * ((n - i) - 1) / 2 + j - i - 1
|
| 1245 |
+
k = (n * n - n) / 2 - k - 1
|
| 1246 |
+
dists.append(self.Pdist_matrix[int(k)][0])
|
| 1247 |
+
else:
|
| 1248 |
+
dists.append(np.inf)
|
| 1249 |
+
return np.atleast_1d(np.array(dists).squeeze())
|
| 1250 |
+
|
| 1251 |
+
|
| 1252 |
+
class Params:
|
| 1253 |
+
"""
|
| 1254 |
+
Params for coco evaluation api
|
| 1255 |
+
"""
|
| 1256 |
+
|
| 1257 |
+
def setDetParams(self):
|
| 1258 |
+
self.imgIds = []
|
| 1259 |
+
self.catIds = []
|
| 1260 |
+
# np.arange causes trouble. the data point on arange is slightly larger than the true value
|
| 1261 |
+
self.iouThrs = np.linspace(0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True)
|
| 1262 |
+
self.recThrs = np.linspace(0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True)
|
| 1263 |
+
self.maxDets = [1, 10, 100]
|
| 1264 |
+
self.areaRng = [
|
| 1265 |
+
[0**2, 1e5**2],
|
| 1266 |
+
[0**2, 32**2],
|
| 1267 |
+
[32**2, 96**2],
|
| 1268 |
+
[96**2, 1e5**2],
|
| 1269 |
+
]
|
| 1270 |
+
self.areaRngLbl = ["all", "small", "medium", "large"]
|
| 1271 |
+
self.useCats = 1
|
| 1272 |
+
|
| 1273 |
+
def setKpParams(self):
|
| 1274 |
+
self.imgIds = []
|
| 1275 |
+
self.catIds = []
|
| 1276 |
+
# np.arange causes trouble. the data point on arange is slightly larger than the true value
|
| 1277 |
+
self.iouThrs = np.linspace(0.5, 0.95, np.round((0.95 - 0.5) / 0.05) + 1, endpoint=True)
|
| 1278 |
+
self.recThrs = np.linspace(0.0, 1.00, np.round((1.00 - 0.0) / 0.01) + 1, endpoint=True)
|
| 1279 |
+
self.maxDets = [20]
|
| 1280 |
+
self.areaRng = [[0**2, 1e5**2], [32**2, 96**2], [96**2, 1e5**2]]
|
| 1281 |
+
self.areaRngLbl = ["all", "medium", "large"]
|
| 1282 |
+
self.useCats = 1
|
| 1283 |
+
|
| 1284 |
+
def setUvParams(self):
|
| 1285 |
+
self.imgIds = []
|
| 1286 |
+
self.catIds = []
|
| 1287 |
+
self.iouThrs = np.linspace(0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True)
|
| 1288 |
+
self.recThrs = np.linspace(0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True)
|
| 1289 |
+
self.maxDets = [20]
|
| 1290 |
+
self.areaRng = [[0**2, 1e5**2], [32**2, 96**2], [96**2, 1e5**2]]
|
| 1291 |
+
self.areaRngLbl = ["all", "medium", "large"]
|
| 1292 |
+
self.useCats = 1
|
| 1293 |
+
|
| 1294 |
+
def __init__(self, iouType="segm"):
|
| 1295 |
+
if iouType == "segm" or iouType == "bbox":
|
| 1296 |
+
self.setDetParams()
|
| 1297 |
+
elif iouType == "keypoints":
|
| 1298 |
+
self.setKpParams()
|
| 1299 |
+
elif iouType == "densepose":
|
| 1300 |
+
self.setUvParams()
|
| 1301 |
+
else:
|
| 1302 |
+
raise Exception("iouType not supported")
|
| 1303 |
+
self.iouType = iouType
|
| 1304 |
+
# useSegm is deprecated
|
| 1305 |
+
self.useSegm = None
|
densepose/evaluation/evaluator.py
ADDED
|
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
|
| 4 |
+
# pyre-unsafe
|
| 5 |
+
|
| 6 |
+
import contextlib
|
| 7 |
+
import copy
|
| 8 |
+
import io
|
| 9 |
+
import itertools
|
| 10 |
+
import logging
|
| 11 |
+
import numpy as np
|
| 12 |
+
import os
|
| 13 |
+
from collections import OrderedDict
|
| 14 |
+
from typing import Dict, Iterable, List, Optional
|
| 15 |
+
import pycocotools.mask as mask_utils
|
| 16 |
+
import torch
|
| 17 |
+
from pycocotools.coco import COCO
|
| 18 |
+
from tabulate import tabulate
|
| 19 |
+
|
| 20 |
+
from detectron2.config import CfgNode
|
| 21 |
+
from detectron2.data import MetadataCatalog
|
| 22 |
+
from detectron2.evaluation import DatasetEvaluator
|
| 23 |
+
from detectron2.structures import BoxMode
|
| 24 |
+
from detectron2.utils.comm import gather, get_rank, is_main_process, synchronize
|
| 25 |
+
from detectron2.utils.file_io import PathManager
|
| 26 |
+
from detectron2.utils.logger import create_small_table
|
| 27 |
+
|
| 28 |
+
from densepose.converters import ToChartResultConverter, ToMaskConverter
|
| 29 |
+
from densepose.data.datasets.coco import maybe_filter_and_map_categories_cocoapi
|
| 30 |
+
from densepose.structures import (
|
| 31 |
+
DensePoseChartPredictorOutput,
|
| 32 |
+
DensePoseEmbeddingPredictorOutput,
|
| 33 |
+
quantize_densepose_chart_result,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
from .densepose_coco_evaluation import DensePoseCocoEval, DensePoseEvalMode
|
| 37 |
+
from .mesh_alignment_evaluator import MeshAlignmentEvaluator
|
| 38 |
+
from .tensor_storage import (
|
| 39 |
+
SingleProcessFileTensorStorage,
|
| 40 |
+
SingleProcessRamTensorStorage,
|
| 41 |
+
SingleProcessTensorStorage,
|
| 42 |
+
SizeData,
|
| 43 |
+
storage_gather,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class DensePoseCOCOEvaluator(DatasetEvaluator):
|
| 48 |
+
def __init__(
|
| 49 |
+
self,
|
| 50 |
+
dataset_name,
|
| 51 |
+
distributed,
|
| 52 |
+
output_dir=None,
|
| 53 |
+
evaluator_type: str = "iuv",
|
| 54 |
+
min_iou_threshold: float = 0.5,
|
| 55 |
+
storage: Optional[SingleProcessTensorStorage] = None,
|
| 56 |
+
embedder=None,
|
| 57 |
+
should_evaluate_mesh_alignment: bool = False,
|
| 58 |
+
mesh_alignment_mesh_names: Optional[List[str]] = None,
|
| 59 |
+
):
|
| 60 |
+
self._embedder = embedder
|
| 61 |
+
self._distributed = distributed
|
| 62 |
+
self._output_dir = output_dir
|
| 63 |
+
self._evaluator_type = evaluator_type
|
| 64 |
+
self._storage = storage
|
| 65 |
+
self._should_evaluate_mesh_alignment = should_evaluate_mesh_alignment
|
| 66 |
+
|
| 67 |
+
assert not (
|
| 68 |
+
should_evaluate_mesh_alignment and embedder is None
|
| 69 |
+
), "Mesh alignment evaluation is activated, but no vertex embedder provided!"
|
| 70 |
+
if should_evaluate_mesh_alignment:
|
| 71 |
+
self._mesh_alignment_evaluator = MeshAlignmentEvaluator(
|
| 72 |
+
embedder,
|
| 73 |
+
mesh_alignment_mesh_names,
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
self._cpu_device = torch.device("cpu")
|
| 77 |
+
self._logger = logging.getLogger(__name__)
|
| 78 |
+
|
| 79 |
+
self._metadata = MetadataCatalog.get(dataset_name)
|
| 80 |
+
self._min_threshold = min_iou_threshold
|
| 81 |
+
json_file = PathManager.get_local_path(self._metadata.json_file)
|
| 82 |
+
with contextlib.redirect_stdout(io.StringIO()):
|
| 83 |
+
self._coco_api = COCO(json_file)
|
| 84 |
+
maybe_filter_and_map_categories_cocoapi(dataset_name, self._coco_api)
|
| 85 |
+
|
| 86 |
+
def reset(self):
|
| 87 |
+
self._predictions = []
|
| 88 |
+
|
| 89 |
+
def process(self, inputs, outputs):
|
| 90 |
+
"""
|
| 91 |
+
Args:
|
| 92 |
+
inputs: the inputs to a COCO model (e.g., GeneralizedRCNN).
|
| 93 |
+
It is a list of dict. Each dict corresponds to an image and
|
| 94 |
+
contains keys like "height", "width", "file_name", "image_id".
|
| 95 |
+
outputs: the outputs of a COCO model. It is a list of dicts with key
|
| 96 |
+
"instances" that contains :class:`Instances`.
|
| 97 |
+
The :class:`Instances` object needs to have `densepose` field.
|
| 98 |
+
"""
|
| 99 |
+
for input, output in zip(inputs, outputs):
|
| 100 |
+
instances = output["instances"].to(self._cpu_device)
|
| 101 |
+
if not instances.has("pred_densepose"):
|
| 102 |
+
continue
|
| 103 |
+
prediction_list = prediction_to_dict(
|
| 104 |
+
instances,
|
| 105 |
+
input["image_id"],
|
| 106 |
+
self._embedder,
|
| 107 |
+
self._metadata.class_to_mesh_name,
|
| 108 |
+
self._storage is not None,
|
| 109 |
+
)
|
| 110 |
+
if self._storage is not None:
|
| 111 |
+
for prediction_dict in prediction_list:
|
| 112 |
+
dict_to_store = {}
|
| 113 |
+
for field_name in self._storage.data_schema:
|
| 114 |
+
dict_to_store[field_name] = prediction_dict[field_name]
|
| 115 |
+
record_id = self._storage.put(dict_to_store)
|
| 116 |
+
prediction_dict["record_id"] = record_id
|
| 117 |
+
prediction_dict["rank"] = get_rank()
|
| 118 |
+
for field_name in self._storage.data_schema:
|
| 119 |
+
del prediction_dict[field_name]
|
| 120 |
+
self._predictions.extend(prediction_list)
|
| 121 |
+
|
| 122 |
+
def evaluate(self, img_ids=None):
|
| 123 |
+
if self._distributed:
|
| 124 |
+
synchronize()
|
| 125 |
+
predictions = gather(self._predictions)
|
| 126 |
+
predictions = list(itertools.chain(*predictions))
|
| 127 |
+
else:
|
| 128 |
+
predictions = self._predictions
|
| 129 |
+
|
| 130 |
+
multi_storage = storage_gather(self._storage) if self._storage is not None else None
|
| 131 |
+
|
| 132 |
+
if not is_main_process():
|
| 133 |
+
return
|
| 134 |
+
return copy.deepcopy(self._eval_predictions(predictions, multi_storage, img_ids))
|
| 135 |
+
|
| 136 |
+
def _eval_predictions(self, predictions, multi_storage=None, img_ids=None):
|
| 137 |
+
"""
|
| 138 |
+
Evaluate predictions on densepose.
|
| 139 |
+
Return results with the metrics of the tasks.
|
| 140 |
+
"""
|
| 141 |
+
self._logger.info("Preparing results for COCO format ...")
|
| 142 |
+
|
| 143 |
+
if self._output_dir:
|
| 144 |
+
PathManager.mkdirs(self._output_dir)
|
| 145 |
+
file_path = os.path.join(self._output_dir, "coco_densepose_predictions.pth")
|
| 146 |
+
with PathManager.open(file_path, "wb") as f:
|
| 147 |
+
torch.save(predictions, f)
|
| 148 |
+
|
| 149 |
+
self._logger.info("Evaluating predictions ...")
|
| 150 |
+
res = OrderedDict()
|
| 151 |
+
results_gps, results_gpsm, results_segm = _evaluate_predictions_on_coco(
|
| 152 |
+
self._coco_api,
|
| 153 |
+
predictions,
|
| 154 |
+
multi_storage,
|
| 155 |
+
self._embedder,
|
| 156 |
+
class_names=self._metadata.get("thing_classes"),
|
| 157 |
+
min_threshold=self._min_threshold,
|
| 158 |
+
img_ids=img_ids,
|
| 159 |
+
)
|
| 160 |
+
res["densepose_gps"] = results_gps
|
| 161 |
+
res["densepose_gpsm"] = results_gpsm
|
| 162 |
+
res["densepose_segm"] = results_segm
|
| 163 |
+
if self._should_evaluate_mesh_alignment:
|
| 164 |
+
res["densepose_mesh_alignment"] = self._evaluate_mesh_alignment()
|
| 165 |
+
return res
|
| 166 |
+
|
| 167 |
+
def _evaluate_mesh_alignment(self):
|
| 168 |
+
self._logger.info("Mesh alignment evaluation ...")
|
| 169 |
+
mean_ge, mean_gps, per_mesh_metrics = self._mesh_alignment_evaluator.evaluate()
|
| 170 |
+
results = {
|
| 171 |
+
"GE": mean_ge * 100,
|
| 172 |
+
"GPS": mean_gps * 100,
|
| 173 |
+
}
|
| 174 |
+
mesh_names = set()
|
| 175 |
+
for metric_name in per_mesh_metrics:
|
| 176 |
+
for mesh_name, value in per_mesh_metrics[metric_name].items():
|
| 177 |
+
results[f"{metric_name}-{mesh_name}"] = value * 100
|
| 178 |
+
mesh_names.add(mesh_name)
|
| 179 |
+
self._print_mesh_alignment_results(results, mesh_names)
|
| 180 |
+
return results
|
| 181 |
+
|
| 182 |
+
def _print_mesh_alignment_results(self, results: Dict[str, float], mesh_names: Iterable[str]):
|
| 183 |
+
self._logger.info("Evaluation results for densepose, mesh alignment:")
|
| 184 |
+
self._logger.info(f'| {"Mesh":13s} | {"GErr":7s} | {"GPS":7s} |')
|
| 185 |
+
self._logger.info("| :-----------: | :-----: | :-----: |")
|
| 186 |
+
for mesh_name in mesh_names:
|
| 187 |
+
ge_key = f"GE-{mesh_name}"
|
| 188 |
+
ge_str = f"{results[ge_key]:.4f}" if ge_key in results else " "
|
| 189 |
+
gps_key = f"GPS-{mesh_name}"
|
| 190 |
+
gps_str = f"{results[gps_key]:.4f}" if gps_key in results else " "
|
| 191 |
+
self._logger.info(f"| {mesh_name:13s} | {ge_str:7s} | {gps_str:7s} |")
|
| 192 |
+
self._logger.info("| :-------------------------------: |")
|
| 193 |
+
ge_key = "GE"
|
| 194 |
+
ge_str = f"{results[ge_key]:.4f}" if ge_key in results else " "
|
| 195 |
+
gps_key = "GPS"
|
| 196 |
+
gps_str = f"{results[gps_key]:.4f}" if gps_key in results else " "
|
| 197 |
+
self._logger.info(f'| {"MEAN":13s} | {ge_str:7s} | {gps_str:7s} |')
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def prediction_to_dict(instances, img_id, embedder, class_to_mesh_name, use_storage):
|
| 201 |
+
"""
|
| 202 |
+
Args:
|
| 203 |
+
instances (Instances): the output of the model
|
| 204 |
+
img_id (str): the image id in COCO
|
| 205 |
+
|
| 206 |
+
Returns:
|
| 207 |
+
list[dict]: the results in densepose evaluation format
|
| 208 |
+
"""
|
| 209 |
+
scores = instances.scores.tolist()
|
| 210 |
+
classes = instances.pred_classes.tolist()
|
| 211 |
+
raw_boxes_xywh = BoxMode.convert(
|
| 212 |
+
instances.pred_boxes.tensor.clone(), BoxMode.XYXY_ABS, BoxMode.XYWH_ABS
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
if isinstance(instances.pred_densepose, DensePoseEmbeddingPredictorOutput):
|
| 216 |
+
results_densepose = densepose_cse_predictions_to_dict(
|
| 217 |
+
instances, embedder, class_to_mesh_name, use_storage
|
| 218 |
+
)
|
| 219 |
+
elif isinstance(instances.pred_densepose, DensePoseChartPredictorOutput):
|
| 220 |
+
if not use_storage:
|
| 221 |
+
results_densepose = densepose_chart_predictions_to_dict(instances)
|
| 222 |
+
else:
|
| 223 |
+
results_densepose = densepose_chart_predictions_to_storage_dict(instances)
|
| 224 |
+
|
| 225 |
+
results = []
|
| 226 |
+
for k in range(len(instances)):
|
| 227 |
+
result = {
|
| 228 |
+
"image_id": img_id,
|
| 229 |
+
"category_id": classes[k],
|
| 230 |
+
"bbox": raw_boxes_xywh[k].tolist(),
|
| 231 |
+
"score": scores[k],
|
| 232 |
+
}
|
| 233 |
+
results.append({**result, **results_densepose[k]})
|
| 234 |
+
return results
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def densepose_chart_predictions_to_dict(instances):
|
| 238 |
+
segmentations = ToMaskConverter.convert(
|
| 239 |
+
instances.pred_densepose, instances.pred_boxes, instances.image_size
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
results = []
|
| 243 |
+
for k in range(len(instances)):
|
| 244 |
+
densepose_results_quantized = quantize_densepose_chart_result(
|
| 245 |
+
ToChartResultConverter.convert(instances.pred_densepose[k], instances.pred_boxes[k])
|
| 246 |
+
)
|
| 247 |
+
densepose_results_quantized.labels_uv_uint8 = (
|
| 248 |
+
densepose_results_quantized.labels_uv_uint8.cpu()
|
| 249 |
+
)
|
| 250 |
+
segmentation = segmentations.tensor[k]
|
| 251 |
+
segmentation_encoded = mask_utils.encode(
|
| 252 |
+
np.require(segmentation.numpy(), dtype=np.uint8, requirements=["F"])
|
| 253 |
+
)
|
| 254 |
+
segmentation_encoded["counts"] = segmentation_encoded["counts"].decode("utf-8")
|
| 255 |
+
result = {
|
| 256 |
+
"densepose": densepose_results_quantized,
|
| 257 |
+
"segmentation": segmentation_encoded,
|
| 258 |
+
}
|
| 259 |
+
results.append(result)
|
| 260 |
+
return results
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def densepose_chart_predictions_to_storage_dict(instances):
|
| 264 |
+
results = []
|
| 265 |
+
for k in range(len(instances)):
|
| 266 |
+
densepose_predictor_output = instances.pred_densepose[k]
|
| 267 |
+
result = {
|
| 268 |
+
"coarse_segm": densepose_predictor_output.coarse_segm.squeeze(0).cpu(),
|
| 269 |
+
"fine_segm": densepose_predictor_output.fine_segm.squeeze(0).cpu(),
|
| 270 |
+
"u": densepose_predictor_output.u.squeeze(0).cpu(),
|
| 271 |
+
"v": densepose_predictor_output.v.squeeze(0).cpu(),
|
| 272 |
+
}
|
| 273 |
+
results.append(result)
|
| 274 |
+
return results
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def densepose_cse_predictions_to_dict(instances, embedder, class_to_mesh_name, use_storage):
|
| 278 |
+
results = []
|
| 279 |
+
for k in range(len(instances)):
|
| 280 |
+
cse = instances.pred_densepose[k]
|
| 281 |
+
results.append(
|
| 282 |
+
{
|
| 283 |
+
"coarse_segm": cse.coarse_segm[0].cpu(),
|
| 284 |
+
"embedding": cse.embedding[0].cpu(),
|
| 285 |
+
}
|
| 286 |
+
)
|
| 287 |
+
return results
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def _evaluate_predictions_on_coco(
|
| 291 |
+
coco_gt,
|
| 292 |
+
coco_results,
|
| 293 |
+
multi_storage=None,
|
| 294 |
+
embedder=None,
|
| 295 |
+
class_names=None,
|
| 296 |
+
min_threshold: float = 0.5,
|
| 297 |
+
img_ids=None,
|
| 298 |
+
):
|
| 299 |
+
logger = logging.getLogger(__name__)
|
| 300 |
+
|
| 301 |
+
densepose_metrics = _get_densepose_metrics(min_threshold)
|
| 302 |
+
if len(coco_results) == 0: # cocoapi does not handle empty results very well
|
| 303 |
+
logger.warn("No predictions from the model! Set scores to -1")
|
| 304 |
+
results_gps = {metric: -1 for metric in densepose_metrics}
|
| 305 |
+
results_gpsm = {metric: -1 for metric in densepose_metrics}
|
| 306 |
+
results_segm = {metric: -1 for metric in densepose_metrics}
|
| 307 |
+
return results_gps, results_gpsm, results_segm
|
| 308 |
+
|
| 309 |
+
coco_dt = coco_gt.loadRes(coco_results)
|
| 310 |
+
|
| 311 |
+
results = []
|
| 312 |
+
for eval_mode_name in ["GPS", "GPSM", "IOU"]:
|
| 313 |
+
eval_mode = getattr(DensePoseEvalMode, eval_mode_name)
|
| 314 |
+
coco_eval = DensePoseCocoEval(
|
| 315 |
+
coco_gt, coco_dt, "densepose", multi_storage, embedder, dpEvalMode=eval_mode
|
| 316 |
+
)
|
| 317 |
+
result = _derive_results_from_coco_eval(
|
| 318 |
+
coco_eval, eval_mode_name, densepose_metrics, class_names, min_threshold, img_ids
|
| 319 |
+
)
|
| 320 |
+
results.append(result)
|
| 321 |
+
return results
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def _get_densepose_metrics(min_threshold: float = 0.5):
|
| 325 |
+
metrics = ["AP"]
|
| 326 |
+
if min_threshold <= 0.201:
|
| 327 |
+
metrics += ["AP20"]
|
| 328 |
+
if min_threshold <= 0.301:
|
| 329 |
+
metrics += ["AP30"]
|
| 330 |
+
if min_threshold <= 0.401:
|
| 331 |
+
metrics += ["AP40"]
|
| 332 |
+
metrics.extend(["AP50", "AP75", "APm", "APl", "AR", "AR50", "AR75", "ARm", "ARl"])
|
| 333 |
+
return metrics
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def _derive_results_from_coco_eval(
|
| 337 |
+
coco_eval, eval_mode_name, metrics, class_names, min_threshold: float, img_ids
|
| 338 |
+
):
|
| 339 |
+
if img_ids is not None:
|
| 340 |
+
coco_eval.params.imgIds = img_ids
|
| 341 |
+
coco_eval.params.iouThrs = np.linspace(
|
| 342 |
+
min_threshold, 0.95, int(np.round((0.95 - min_threshold) / 0.05)) + 1, endpoint=True
|
| 343 |
+
)
|
| 344 |
+
coco_eval.evaluate()
|
| 345 |
+
coco_eval.accumulate()
|
| 346 |
+
coco_eval.summarize()
|
| 347 |
+
results = {metric: float(coco_eval.stats[idx] * 100) for idx, metric in enumerate(metrics)}
|
| 348 |
+
logger = logging.getLogger(__name__)
|
| 349 |
+
logger.info(
|
| 350 |
+
f"Evaluation results for densepose, {eval_mode_name} metric: \n"
|
| 351 |
+
+ create_small_table(results)
|
| 352 |
+
)
|
| 353 |
+
if class_names is None or len(class_names) <= 1:
|
| 354 |
+
return results
|
| 355 |
+
|
| 356 |
+
# Compute per-category AP, the same way as it is done in D2
|
| 357 |
+
# (see detectron2/evaluation/coco_evaluation.py):
|
| 358 |
+
precisions = coco_eval.eval["precision"]
|
| 359 |
+
# precision has dims (iou, recall, cls, area range, max dets)
|
| 360 |
+
assert len(class_names) == precisions.shape[2]
|
| 361 |
+
|
| 362 |
+
results_per_category = []
|
| 363 |
+
for idx, name in enumerate(class_names):
|
| 364 |
+
# area range index 0: all area ranges
|
| 365 |
+
# max dets index -1: typically 100 per image
|
| 366 |
+
precision = precisions[:, :, idx, 0, -1]
|
| 367 |
+
precision = precision[precision > -1]
|
| 368 |
+
ap = np.mean(precision) if precision.size else float("nan")
|
| 369 |
+
results_per_category.append((f"{name}", float(ap * 100)))
|
| 370 |
+
|
| 371 |
+
# tabulate it
|
| 372 |
+
n_cols = min(6, len(results_per_category) * 2)
|
| 373 |
+
results_flatten = list(itertools.chain(*results_per_category))
|
| 374 |
+
results_2d = itertools.zip_longest(*[results_flatten[i::n_cols] for i in range(n_cols)])
|
| 375 |
+
table = tabulate(
|
| 376 |
+
results_2d,
|
| 377 |
+
tablefmt="pipe",
|
| 378 |
+
floatfmt=".3f",
|
| 379 |
+
headers=["category", "AP"] * (n_cols // 2),
|
| 380 |
+
numalign="left",
|
| 381 |
+
)
|
| 382 |
+
logger.info(f"Per-category {eval_mode_name} AP: \n" + table)
|
| 383 |
+
|
| 384 |
+
results.update({"AP-" + name: ap for name, ap in results_per_category})
|
| 385 |
+
return results
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def build_densepose_evaluator_storage(cfg: CfgNode, output_folder: str):
|
| 389 |
+
storage_spec = cfg.DENSEPOSE_EVALUATION.STORAGE
|
| 390 |
+
if storage_spec == "none":
|
| 391 |
+
return None
|
| 392 |
+
evaluator_type = cfg.DENSEPOSE_EVALUATION.TYPE
|
| 393 |
+
# common output tensor sizes
|
| 394 |
+
hout = cfg.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE
|
| 395 |
+
wout = cfg.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE
|
| 396 |
+
n_csc = cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS
|
| 397 |
+
# specific output tensors
|
| 398 |
+
if evaluator_type == "iuv":
|
| 399 |
+
n_fsc = cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_PATCHES + 1
|
| 400 |
+
schema = {
|
| 401 |
+
"coarse_segm": SizeData(dtype="float32", shape=(n_csc, hout, wout)),
|
| 402 |
+
"fine_segm": SizeData(dtype="float32", shape=(n_fsc, hout, wout)),
|
| 403 |
+
"u": SizeData(dtype="float32", shape=(n_fsc, hout, wout)),
|
| 404 |
+
"v": SizeData(dtype="float32", shape=(n_fsc, hout, wout)),
|
| 405 |
+
}
|
| 406 |
+
elif evaluator_type == "cse":
|
| 407 |
+
embed_size = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE
|
| 408 |
+
schema = {
|
| 409 |
+
"coarse_segm": SizeData(dtype="float32", shape=(n_csc, hout, wout)),
|
| 410 |
+
"embedding": SizeData(dtype="float32", shape=(embed_size, hout, wout)),
|
| 411 |
+
}
|
| 412 |
+
else:
|
| 413 |
+
raise ValueError(f"Unknown evaluator type: {evaluator_type}")
|
| 414 |
+
# storage types
|
| 415 |
+
if storage_spec == "ram":
|
| 416 |
+
storage = SingleProcessRamTensorStorage(schema, io.BytesIO())
|
| 417 |
+
elif storage_spec == "file":
|
| 418 |
+
fpath = os.path.join(output_folder, f"DensePoseEvaluatorStorage.{get_rank()}.bin")
|
| 419 |
+
PathManager.mkdirs(output_folder)
|
| 420 |
+
storage = SingleProcessFileTensorStorage(schema, fpath, "wb")
|
| 421 |
+
else:
|
| 422 |
+
raise ValueError(f"Unknown storage specification: {storage_spec}")
|
| 423 |
+
return storage
|
densepose/evaluation/mesh_alignment_evaluator.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import logging
|
| 7 |
+
from typing import List, Optional
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
from detectron2.utils.file_io import PathManager
|
| 12 |
+
|
| 13 |
+
from densepose.structures.mesh import create_mesh
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class MeshAlignmentEvaluator:
|
| 17 |
+
"""
|
| 18 |
+
Class for evaluation of 3D mesh alignment based on the learned vertex embeddings
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, embedder: nn.Module, mesh_names: Optional[List[str]]):
|
| 22 |
+
self.embedder = embedder
|
| 23 |
+
# use the provided mesh names if not None and not an empty list
|
| 24 |
+
self.mesh_names = mesh_names if mesh_names else embedder.mesh_names
|
| 25 |
+
self.logger = logging.getLogger(__name__)
|
| 26 |
+
with PathManager.open(
|
| 27 |
+
"https://dl.fbaipublicfiles.com/densepose/data/cse/mesh_keyvertices_v0.json", "r"
|
| 28 |
+
) as f:
|
| 29 |
+
self.mesh_keyvertices = json.load(f)
|
| 30 |
+
|
| 31 |
+
def evaluate(self):
|
| 32 |
+
ge_per_mesh = {}
|
| 33 |
+
gps_per_mesh = {}
|
| 34 |
+
for mesh_name_1 in self.mesh_names:
|
| 35 |
+
avg_errors = []
|
| 36 |
+
avg_gps = []
|
| 37 |
+
embeddings_1 = self.embedder(mesh_name_1)
|
| 38 |
+
keyvertices_1 = self.mesh_keyvertices[mesh_name_1]
|
| 39 |
+
keyvertex_names_1 = list(keyvertices_1.keys())
|
| 40 |
+
keyvertex_indices_1 = [keyvertices_1[name] for name in keyvertex_names_1]
|
| 41 |
+
for mesh_name_2 in self.mesh_names:
|
| 42 |
+
if mesh_name_1 == mesh_name_2:
|
| 43 |
+
continue
|
| 44 |
+
embeddings_2 = self.embedder(mesh_name_2)
|
| 45 |
+
keyvertices_2 = self.mesh_keyvertices[mesh_name_2]
|
| 46 |
+
sim_matrix_12 = embeddings_1[keyvertex_indices_1].mm(embeddings_2.T)
|
| 47 |
+
vertices_2_matching_keyvertices_1 = sim_matrix_12.argmax(axis=1)
|
| 48 |
+
mesh_2 = create_mesh(mesh_name_2, embeddings_2.device)
|
| 49 |
+
geodists = mesh_2.geodists[
|
| 50 |
+
vertices_2_matching_keyvertices_1,
|
| 51 |
+
[keyvertices_2[name] for name in keyvertex_names_1],
|
| 52 |
+
]
|
| 53 |
+
Current_Mean_Distances = 0.255
|
| 54 |
+
gps = (-(geodists**2) / (2 * (Current_Mean_Distances**2))).exp()
|
| 55 |
+
avg_errors.append(geodists.mean().item())
|
| 56 |
+
avg_gps.append(gps.mean().item())
|
| 57 |
+
|
| 58 |
+
ge_mean = torch.as_tensor(avg_errors).mean().item()
|
| 59 |
+
gps_mean = torch.as_tensor(avg_gps).mean().item()
|
| 60 |
+
ge_per_mesh[mesh_name_1] = ge_mean
|
| 61 |
+
gps_per_mesh[mesh_name_1] = gps_mean
|
| 62 |
+
ge_mean_global = torch.as_tensor(list(ge_per_mesh.values())).mean().item()
|
| 63 |
+
gps_mean_global = torch.as_tensor(list(gps_per_mesh.values())).mean().item()
|
| 64 |
+
per_mesh_metrics = {
|
| 65 |
+
"GE": ge_per_mesh,
|
| 66 |
+
"GPS": gps_per_mesh,
|
| 67 |
+
}
|
| 68 |
+
return ge_mean_global, gps_mean_global, per_mesh_metrics
|
densepose/evaluation/tensor_storage.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
import io
|
| 6 |
+
import numpy as np
|
| 7 |
+
import os
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from functools import reduce
|
| 10 |
+
from operator import mul
|
| 11 |
+
from typing import BinaryIO, Dict, Optional, Tuple
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
from detectron2.utils.comm import gather, get_rank
|
| 15 |
+
from detectron2.utils.file_io import PathManager
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class SizeData:
|
| 20 |
+
dtype: str
|
| 21 |
+
shape: Tuple[int]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _calculate_record_field_size_b(data_schema: Dict[str, SizeData], field_name: str) -> int:
|
| 25 |
+
schema = data_schema[field_name]
|
| 26 |
+
element_size_b = np.dtype(schema.dtype).itemsize
|
| 27 |
+
record_field_size_b = reduce(mul, schema.shape) * element_size_b
|
| 28 |
+
return record_field_size_b
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _calculate_record_size_b(data_schema: Dict[str, SizeData]) -> int:
|
| 32 |
+
record_size_b = 0
|
| 33 |
+
for field_name in data_schema:
|
| 34 |
+
record_field_size_b = _calculate_record_field_size_b(data_schema, field_name)
|
| 35 |
+
record_size_b += record_field_size_b
|
| 36 |
+
return record_size_b
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _calculate_record_field_sizes_b(data_schema: Dict[str, SizeData]) -> Dict[str, int]:
|
| 40 |
+
field_sizes_b = {}
|
| 41 |
+
for field_name in data_schema:
|
| 42 |
+
field_sizes_b[field_name] = _calculate_record_field_size_b(data_schema, field_name)
|
| 43 |
+
return field_sizes_b
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class SingleProcessTensorStorage:
|
| 47 |
+
"""
|
| 48 |
+
Compact tensor storage to keep tensor data of predefined size and type.
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
def __init__(self, data_schema: Dict[str, SizeData], storage_impl: BinaryIO):
|
| 52 |
+
"""
|
| 53 |
+
Construct tensor storage based on information on data shape and size.
|
| 54 |
+
Internally uses numpy to interpret the type specification.
|
| 55 |
+
The storage must support operations `seek(offset, whence=os.SEEK_SET)` and
|
| 56 |
+
`read(size)` to be able to perform the `get` operation.
|
| 57 |
+
The storage must support operation `write(bytes)` to be able to perform
|
| 58 |
+
the `put` operation.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
data_schema (dict: str -> SizeData): dictionary which maps tensor name
|
| 62 |
+
to its size data (shape and data type), e.g.
|
| 63 |
+
```
|
| 64 |
+
{
|
| 65 |
+
"coarse_segm": SizeData(dtype="float32", shape=(112, 112)),
|
| 66 |
+
"embedding": SizeData(dtype="float32", shape=(16, 112, 112)),
|
| 67 |
+
}
|
| 68 |
+
```
|
| 69 |
+
storage_impl (BinaryIO): io instance that handles file-like seek, read
|
| 70 |
+
and write operations, e.g. a file handle or a memory buffer like io.BytesIO
|
| 71 |
+
"""
|
| 72 |
+
self.data_schema = data_schema
|
| 73 |
+
self.record_size_b = _calculate_record_size_b(data_schema)
|
| 74 |
+
self.record_field_sizes_b = _calculate_record_field_sizes_b(data_schema)
|
| 75 |
+
self.storage_impl = storage_impl
|
| 76 |
+
self.next_record_id = 0
|
| 77 |
+
|
| 78 |
+
def get(self, record_id: int) -> Dict[str, torch.Tensor]:
|
| 79 |
+
"""
|
| 80 |
+
Load tensors from the storage by record ID
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
record_id (int): Record ID, for which to load the data
|
| 84 |
+
|
| 85 |
+
Return:
|
| 86 |
+
dict: str -> tensor: tensor name mapped to tensor data, recorded under the provided ID
|
| 87 |
+
"""
|
| 88 |
+
self.storage_impl.seek(record_id * self.record_size_b, os.SEEK_SET)
|
| 89 |
+
data_bytes = self.storage_impl.read(self.record_size_b)
|
| 90 |
+
assert len(data_bytes) == self.record_size_b, (
|
| 91 |
+
f"Expected data size {self.record_size_b} B could not be read: "
|
| 92 |
+
f"got {len(data_bytes)} B"
|
| 93 |
+
)
|
| 94 |
+
record = {}
|
| 95 |
+
cur_idx = 0
|
| 96 |
+
# it's important to read and write in the same order
|
| 97 |
+
for field_name in sorted(self.data_schema):
|
| 98 |
+
schema = self.data_schema[field_name]
|
| 99 |
+
field_size_b = self.record_field_sizes_b[field_name]
|
| 100 |
+
chunk = data_bytes[cur_idx : cur_idx + field_size_b]
|
| 101 |
+
data_np = np.frombuffer(
|
| 102 |
+
chunk, dtype=schema.dtype, count=reduce(mul, schema.shape)
|
| 103 |
+
).reshape(schema.shape)
|
| 104 |
+
record[field_name] = torch.from_numpy(data_np)
|
| 105 |
+
cur_idx += field_size_b
|
| 106 |
+
return record
|
| 107 |
+
|
| 108 |
+
def put(self, data: Dict[str, torch.Tensor]) -> int:
|
| 109 |
+
"""
|
| 110 |
+
Store tensors in the storage
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
data (dict: str -> tensor): data to store, a dictionary which maps
|
| 114 |
+
tensor names into tensors; tensor shapes must match those specified
|
| 115 |
+
in data schema.
|
| 116 |
+
Return:
|
| 117 |
+
int: record ID, under which the data is stored
|
| 118 |
+
"""
|
| 119 |
+
# it's important to read and write in the same order
|
| 120 |
+
for field_name in sorted(self.data_schema):
|
| 121 |
+
assert (
|
| 122 |
+
field_name in data
|
| 123 |
+
), f"Field '{field_name}' not present in data: data keys are {data.keys()}"
|
| 124 |
+
value = data[field_name]
|
| 125 |
+
assert value.shape == self.data_schema[field_name].shape, (
|
| 126 |
+
f"Mismatched tensor shapes for field '{field_name}': "
|
| 127 |
+
f"expected {self.data_schema[field_name].shape}, got {value.shape}"
|
| 128 |
+
)
|
| 129 |
+
data_bytes = value.cpu().numpy().tobytes()
|
| 130 |
+
assert len(data_bytes) == self.record_field_sizes_b[field_name], (
|
| 131 |
+
f"Expected field {field_name} to be of size "
|
| 132 |
+
f"{self.record_field_sizes_b[field_name]} B, got {len(data_bytes)} B"
|
| 133 |
+
)
|
| 134 |
+
self.storage_impl.write(data_bytes)
|
| 135 |
+
record_id = self.next_record_id
|
| 136 |
+
self.next_record_id += 1
|
| 137 |
+
return record_id
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class SingleProcessFileTensorStorage(SingleProcessTensorStorage):
|
| 141 |
+
"""
|
| 142 |
+
Implementation of a single process tensor storage which stores data in a file
|
| 143 |
+
"""
|
| 144 |
+
|
| 145 |
+
def __init__(self, data_schema: Dict[str, SizeData], fpath: str, mode: str):
|
| 146 |
+
self.fpath = fpath
|
| 147 |
+
assert "b" in mode, f"Tensor storage should be opened in binary mode, got '{mode}'"
|
| 148 |
+
if "w" in mode:
|
| 149 |
+
# pyre-fixme[6]: For 2nd argument expected `Union[typing_extensions.Liter...
|
| 150 |
+
file_h = PathManager.open(fpath, mode)
|
| 151 |
+
elif "r" in mode:
|
| 152 |
+
local_fpath = PathManager.get_local_path(fpath)
|
| 153 |
+
file_h = open(local_fpath, mode)
|
| 154 |
+
else:
|
| 155 |
+
raise ValueError(f"Unsupported file mode {mode}, supported modes: rb, wb")
|
| 156 |
+
super().__init__(data_schema, file_h) # pyre-ignore[6]
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class SingleProcessRamTensorStorage(SingleProcessTensorStorage):
|
| 160 |
+
"""
|
| 161 |
+
Implementation of a single process tensor storage which stores data in RAM
|
| 162 |
+
"""
|
| 163 |
+
|
| 164 |
+
def __init__(self, data_schema: Dict[str, SizeData], buf: io.BytesIO):
|
| 165 |
+
super().__init__(data_schema, buf)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class MultiProcessTensorStorage:
|
| 169 |
+
"""
|
| 170 |
+
Representation of a set of tensor storages created by individual processes,
|
| 171 |
+
allows to access those storages from a single owner process. The storages
|
| 172 |
+
should either be shared or broadcasted to the owner process.
|
| 173 |
+
The processes are identified by their rank, data is uniquely defined by
|
| 174 |
+
the rank of the process and the record ID.
|
| 175 |
+
"""
|
| 176 |
+
|
| 177 |
+
def __init__(self, rank_to_storage: Dict[int, SingleProcessTensorStorage]):
|
| 178 |
+
self.rank_to_storage = rank_to_storage
|
| 179 |
+
|
| 180 |
+
def get(self, rank: int, record_id: int) -> Dict[str, torch.Tensor]:
|
| 181 |
+
storage = self.rank_to_storage[rank]
|
| 182 |
+
return storage.get(record_id)
|
| 183 |
+
|
| 184 |
+
def put(self, rank: int, data: Dict[str, torch.Tensor]) -> int:
|
| 185 |
+
storage = self.rank_to_storage[rank]
|
| 186 |
+
return storage.put(data)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class MultiProcessFileTensorStorage(MultiProcessTensorStorage):
|
| 190 |
+
def __init__(self, data_schema: Dict[str, SizeData], rank_to_fpath: Dict[int, str], mode: str):
|
| 191 |
+
rank_to_storage = {
|
| 192 |
+
rank: SingleProcessFileTensorStorage(data_schema, fpath, mode)
|
| 193 |
+
for rank, fpath in rank_to_fpath.items()
|
| 194 |
+
}
|
| 195 |
+
super().__init__(rank_to_storage) # pyre-ignore[6]
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class MultiProcessRamTensorStorage(MultiProcessTensorStorage):
|
| 199 |
+
def __init__(self, data_schema: Dict[str, SizeData], rank_to_buffer: Dict[int, io.BytesIO]):
|
| 200 |
+
rank_to_storage = {
|
| 201 |
+
rank: SingleProcessRamTensorStorage(data_schema, buf)
|
| 202 |
+
for rank, buf in rank_to_buffer.items()
|
| 203 |
+
}
|
| 204 |
+
super().__init__(rank_to_storage) # pyre-ignore[6]
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def _ram_storage_gather(
|
| 208 |
+
storage: SingleProcessRamTensorStorage, dst_rank: int = 0
|
| 209 |
+
) -> Optional[MultiProcessRamTensorStorage]:
|
| 210 |
+
storage.storage_impl.seek(0, os.SEEK_SET)
|
| 211 |
+
# TODO: overhead, pickling a bytes object, can just pass bytes in a tensor directly
|
| 212 |
+
# see detectron2/utils.comm.py
|
| 213 |
+
data_list = gather(storage.storage_impl.read(), dst=dst_rank)
|
| 214 |
+
if get_rank() != dst_rank:
|
| 215 |
+
return None
|
| 216 |
+
rank_to_buffer = {i: io.BytesIO(data_list[i]) for i in range(len(data_list))}
|
| 217 |
+
multiprocess_storage = MultiProcessRamTensorStorage(storage.data_schema, rank_to_buffer)
|
| 218 |
+
return multiprocess_storage
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def _file_storage_gather(
|
| 222 |
+
storage: SingleProcessFileTensorStorage,
|
| 223 |
+
dst_rank: int = 0,
|
| 224 |
+
mode: str = "rb",
|
| 225 |
+
) -> Optional[MultiProcessFileTensorStorage]:
|
| 226 |
+
storage.storage_impl.close()
|
| 227 |
+
fpath_list = gather(storage.fpath, dst=dst_rank)
|
| 228 |
+
if get_rank() != dst_rank:
|
| 229 |
+
return None
|
| 230 |
+
rank_to_fpath = {i: fpath_list[i] for i in range(len(fpath_list))}
|
| 231 |
+
return MultiProcessFileTensorStorage(storage.data_schema, rank_to_fpath, mode)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def storage_gather(
|
| 235 |
+
storage: SingleProcessTensorStorage, dst_rank: int = 0
|
| 236 |
+
) -> Optional[MultiProcessTensorStorage]:
|
| 237 |
+
if isinstance(storage, SingleProcessRamTensorStorage):
|
| 238 |
+
return _ram_storage_gather(storage, dst_rank)
|
| 239 |
+
elif isinstance(storage, SingleProcessFileTensorStorage):
|
| 240 |
+
return _file_storage_gather(storage, dst_rank)
|
| 241 |
+
raise Exception(f"Unsupported storage for gather operation: {storage}")
|